From dec8194eb2bf227027e67cf0a158c9aa9171c691 Mon Sep 17 00:00:00 2001 From: William Zhang <17zhangw@gmail.com> Date: Mon, 1 Apr 2019 18:53:11 -0400 Subject: [PATCH 01/14] Templatized some of the core files: - pattern - rule - ruleset - group - groupexpression - binding - memo - optimize_context - optimizer_task (TopDownRewrite/BottomUpRewrite) Templates generally followed: template <class Node, class OperatorType, class OperatorExpr> The template instantiation associated with: Node = Operator, OperatorType = OpType, OperatorExpr = OperatorExpression is used primarily by the core Optimizer. All references to the templated files/classes from core optimizer files were instantiated to that. Note worth mentioning: Operator class defines a public interface wrapper around BaseOperatorNode, basically defines a single logical/physical operator. OpType class defines the various logical/physical operations OperatorExpression class is essentially a tree of Operator --- src/include/optimizer/binding.h | 45 ++-- .../optimizer/child_property_deriver.h | 12 +- .../cost_model/abstract_cost_model.h | 5 +- .../optimizer/cost_model/default_cost_model.h | 11 +- .../cost_model/postgres_cost_model.h | 11 +- .../optimizer/cost_model/trivial_cost_model.h | 11 +- src/include/optimizer/group.h | 22 +- src/include/optimizer/group_expression.h | 20 +- src/include/optimizer/input_column_deriver.h | 11 +- src/include/optimizer/memo.h | 36 +-- src/include/optimizer/optimize_context.h | 6 +- src/include/optimizer/optimizer.h | 23 +- src/include/optimizer/optimizer_metadata.h | 30 ++- src/include/optimizer/optimizer_task.h | 111 +++++--- src/include/optimizer/optimizer_task_pool.h | 17 +- src/include/optimizer/pattern.h | 10 +- src/include/optimizer/property_enforcer.h | 8 +- src/include/optimizer/rule.h | 61 +++-- src/include/optimizer/rule_impls.h | 172 ++++++------ .../optimizer/stats/child_stats_deriver.h | 12 +- .../optimizer/stats/stats_calculator.h | 12 +- src/optimizer/binding.cpp | 65 +++-- src/optimizer/child_property_deriver.cpp | 6 +- 
src/optimizer/group.cpp | 47 +++- src/optimizer/group_expression.cpp | 45 +++- src/optimizer/input_column_deriver.cpp | 5 +- src/optimizer/memo.cpp | 116 +++++--- src/optimizer/optimizer.cpp | 32 +-- src/optimizer/optimizer_task.cpp | 117 ++++---- src/optimizer/pattern.cpp | 15 +- src/optimizer/property_enforcer.cpp | 10 +- src/optimizer/rule.cpp | 28 +- src/optimizer/rule_impls.cpp | 250 +++++++++--------- src/optimizer/stats/child_stats_deriver.cpp | 4 +- src/optimizer/stats/stats_calculator.cpp | 4 +- test/include/optimizer/mock_task.h | 4 +- test/optimizer/optimizer_rule_test.cpp | 8 +- test/optimizer/optimizer_test.cpp | 38 +-- 38 files changed, 836 insertions(+), 604 deletions(-) diff --git a/src/include/optimizer/binding.h b/src/include/optimizer/binding.h index 7a6d772813d..616bda57782 100644 --- a/src/include/optimizer/binding.h +++ b/src/include/optimizer/binding.h @@ -24,63 +24,68 @@ namespace peloton { namespace optimizer { class Optimizer; + +template class Memo; //===--------------------------------------------------------------------===// // Binding Iterator //===--------------------------------------------------------------------===// +template class BindingIterator { public: - BindingIterator(Memo& memo) : memo_(memo) {} + BindingIterator(Memo& memo) : memo_(memo) {} virtual ~BindingIterator(){}; virtual bool HasNext() = 0; - virtual std::shared_ptr Next() = 0; + virtual std::shared_ptr Next() = 0; protected: - Memo &memo_; + Memo &memo_; }; -class GroupBindingIterator : public BindingIterator { +template +class GroupBindingIterator : public BindingIterator { public: - GroupBindingIterator(Memo& memo, GroupID id, - std::shared_ptr pattern); + GroupBindingIterator(Memo& memo, + GroupID id, + std::shared_ptr> pattern); bool HasNext() override; - std::shared_ptr Next() override; + std::shared_ptr Next() override; private: GroupID group_id_; - std::shared_ptr pattern_; - Group *target_group_; + std::shared_ptr> pattern_; + Group *target_group_; size_t 
num_group_items_; size_t current_item_index_; - std::unique_ptr current_iterator_; + std::unique_ptr> current_iterator_; }; -class GroupExprBindingIterator : public BindingIterator { +template +class GroupExprBindingIterator : public BindingIterator { public: - GroupExprBindingIterator(Memo& memo, - GroupExpression *gexpr, - std::shared_ptr pattern); + GroupExprBindingIterator(Memo& memo, + GroupExpression *gexpr, + std::shared_ptr> pattern); bool HasNext() override; - std::shared_ptr Next() override; + std::shared_ptr Next() override; private: - GroupExpression* gexpr_; - std::shared_ptr pattern_; + GroupExpression* gexpr_; + std::shared_ptr> pattern_; bool first_; bool has_next_; - std::shared_ptr current_binding_; - std::vector>> - children_bindings_; + std::shared_ptr current_binding_; + std::vector>> children_bindings_; std::vector children_bindings_pos_; }; diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index 914cc77ab27..6ec2c09400a 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -13,10 +13,12 @@ #pragma once #include #include "optimizer/operator_visitor.h" +#include "optimizer/operator_expression.h" namespace peloton { namespace optimizer { +template class Memo; } @@ -33,8 +35,10 @@ class ChildPropertyDeriver : public OperatorVisitor { public: std::vector, std::vector>>> - GetProperties(GroupExpression *gexpr, - std::shared_ptr requirements, Memo *memo); + + GetProperties(GroupExpression *gexpr, + std::shared_ptr requirements, + Memo *memo); void Visit(const DummyScan *) override; void Visit(const PhysicalSeqScan *) override; @@ -74,8 +78,8 @@ class ChildPropertyDeriver : public OperatorVisitor { * @brief We need the memo and gexpr because some property may depend on * child's schema */ - Memo *memo_; - GroupExpression *gexpr_; + Memo *memo_; + GroupExpression *gexpr_; }; } // namespace optimizer diff --git 
a/src/include/optimizer/cost_model/abstract_cost_model.h b/src/include/optimizer/cost_model/abstract_cost_model.h index 95a593f04d9..e01548739b1 100644 --- a/src/include/optimizer/cost_model/abstract_cost_model.h +++ b/src/include/optimizer/cost_model/abstract_cost_model.h @@ -13,10 +13,12 @@ #pragma once #include "optimizer/operator_visitor.h" +#include "optimizer/operator_expression.h" namespace peloton { namespace optimizer { +template class Memo; // Default cost when cost model cannot compute correct cost. @@ -34,7 +36,8 @@ static constexpr double DEFAULT_OPERATOR_COST = 0.0025; class AbstractCostModel : public OperatorVisitor { public: - virtual double CalculateCost(GroupExpression *gexpr, Memo *memo, + virtual double CalculateCost(GroupExpression *gexpr, + Memo *memo, concurrency::TransactionContext *txn) = 0; }; diff --git a/src/include/optimizer/cost_model/default_cost_model.h b/src/include/optimizer/cost_model/default_cost_model.h index a92cb091db7..a89bd4ee3a3 100644 --- a/src/include/optimizer/cost_model/default_cost_model.h +++ b/src/include/optimizer/cost_model/default_cost_model.h @@ -23,14 +23,17 @@ namespace peloton { namespace optimizer { +template class Memo; + // Derive cost for a physical group expression class DefaultCostModel : public AbstractCostModel { public: DefaultCostModel(){}; - double CalculateCost(GroupExpression *gexpr, Memo *memo, - concurrency::TransactionContext *txn) { + double CalculateCost(GroupExpression *gexpr, + Memo *memo, + concurrency::TransactionContext *txn) { gexpr_ = gexpr; memo_ = memo; txn_ = txn; @@ -151,8 +154,8 @@ class DefaultCostModel : public AbstractCostModel { return child_num_rows * DEFAULT_TUPLE_COST; } - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; concurrency::TransactionContext *txn_; double output_cost_ = 0; }; diff --git a/src/include/optimizer/cost_model/postgres_cost_model.h b/src/include/optimizer/cost_model/postgres_cost_model.h index 2632a247a39..523983a89d1 
100644 --- a/src/include/optimizer/cost_model/postgres_cost_model.h +++ b/src/include/optimizer/cost_model/postgres_cost_model.h @@ -28,13 +28,16 @@ namespace peloton { namespace optimizer { +template class Memo; + // Derive cost for a physical group expression class PostgresCostModel : public AbstractCostModel { public: PostgresCostModel(){}; - double CalculateCost(GroupExpression *gexpr, Memo *memo, + double CalculateCost(GroupExpression *gexpr, + Memo *memo, concurrency::TransactionContext *txn) override { gexpr_ = gexpr; memo_ = memo; @@ -230,8 +233,8 @@ class PostgresCostModel : public AbstractCostModel { } - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; concurrency::TransactionContext *txn_; double output_cost_ = 0; @@ -279,4 +282,4 @@ class PostgresCostModel : public AbstractCostModel { }; } // namespace optimizer -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/optimizer/cost_model/trivial_cost_model.h b/src/include/optimizer/cost_model/trivial_cost_model.h index 2c5994ee728..f755626f083 100644 --- a/src/include/optimizer/cost_model/trivial_cost_model.h +++ b/src/include/optimizer/cost_model/trivial_cost_model.h @@ -31,12 +31,15 @@ namespace peloton { namespace optimizer { +template class Memo; + class TrivialCostModel : public AbstractCostModel { public: TrivialCostModel(){}; - double CalculateCost(GroupExpression *gexpr, Memo *memo, + double CalculateCost(GroupExpression *gexpr, + Memo *memo, concurrency::TransactionContext *txn) override { gexpr_ = gexpr; memo_ = memo; @@ -109,11 +112,11 @@ class TrivialCostModel : public AbstractCostModel { } private: - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; concurrency::TransactionContext *txn_; double output_cost_ = 0; }; } // namespace optimizer -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/optimizer/group.h 
b/src/include/optimizer/group.h index a0606d1597c..9129a4952a8 100644 --- a/src/include/optimizer/group.h +++ b/src/include/optimizer/group.h @@ -32,6 +32,7 @@ class ColumnStats; //===--------------------------------------------------------------------===// // Group //===--------------------------------------------------------------------===// +template class Group : public Printable { public: Group(GroupID id, std::unordered_set table_alias); @@ -39,29 +40,30 @@ class Group : public Printable { // If the GroupExpression is generated by applying a // property enforcer, we add them to enforced_exprs_ // which will not be enumerated during OptimizeExpression - void AddExpression(std::shared_ptr expr, bool enforced); + void AddExpression(std::shared_ptr> expr, + bool enforced); void RemoveLogicalExpression(size_t idx) { logical_expressions_.erase(logical_expressions_.begin() + idx); } - bool SetExpressionCost(GroupExpression *expr, double cost, + bool SetExpressionCost(GroupExpression *expr, double cost, std::shared_ptr &properties); - GroupExpression *GetBestExpression(std::shared_ptr &properties); + GroupExpression *GetBestExpression(std::shared_ptr &properties); inline const std::unordered_set &GetTableAliases() const { return table_aliases_; } // TODO: thread safety? - const std::vector> GetLogicalExpressions() + const std::vector>> GetLogicalExpressions() const { return logical_expressions_; } // TODO: thread safety? 
- const std::vector> GetPhysicalExpressions() + const std::vector>> GetPhysicalExpressions() const { return physical_expressions_; } @@ -105,7 +107,7 @@ class Group : public Printable { // This should only be called in rewrite phase to retrieve the only logical // expr in the group - inline GroupExpression *GetLogicalExpression() { + inline GroupExpression *GetLogicalExpression() { PELOTON_ASSERT(logical_expressions_.size() == 1); PELOTON_ASSERT(physical_expressions_.size() == 0); return logical_expressions_[0].get(); @@ -117,15 +119,15 @@ class Group : public Printable { // TODO(boweic) Do not use string, store table alias id std::unordered_set table_aliases_; std::unordered_map, - std::tuple, PropSetPtrHash, + std::tuple *>, PropSetPtrHash, PropSetPtrEq> lowest_cost_expressions_; // Whether equivalent logical expressions have been explored for this group bool has_explored_; - std::vector> logical_expressions_; - std::vector> physical_expressions_; - std::vector> enforced_exprs_; + std::vector>> logical_expressions_; + std::vector>> physical_expressions_; + std::vector>> enforced_exprs_; // We'll add stats lazily // TODO(boweic): diff --git a/src/include/optimizer/group_expression.h b/src/include/optimizer/group_expression.h index 303ebaf036e..af71c9e75e2 100644 --- a/src/include/optimizer/group_expression.h +++ b/src/include/optimizer/group_expression.h @@ -25,6 +25,7 @@ namespace peloton { namespace optimizer { +template class Rule; using GroupID = int32_t; @@ -32,9 +33,10 @@ using GroupID = int32_t; //===--------------------------------------------------------------------===// // Group Expression //===--------------------------------------------------------------------===// +template class GroupExpression { public: - GroupExpression(Operator op, std::vector child_groups); + GroupExpression(Node op, std::vector child_groups); GroupID GetGroupID() const; @@ -46,7 +48,7 @@ class GroupExpression { GroupID GetChildGroupId(int child_idx) const; - Operator Op() const; 
+ Node Op() const; double GetCost(std::shared_ptr& requirements) const; @@ -61,11 +63,11 @@ class GroupExpression { hash_t Hash() const; - bool operator==(const GroupExpression &r); + bool operator==(const GroupExpression &r); - void SetRuleExplored(Rule *rule); + void SetRuleExplored(Rule *rule); - bool HasRuleExplored(Rule *rule); + bool HasRuleExplored(Rule *rule); void SetDerivedStats() { stats_derived_ = true; } @@ -75,7 +77,7 @@ class GroupExpression { private: GroupID group_id; - Operator op; + Node op; std::vector child_groups; std::bitset(RuleType::NUM_RULES)> rule_mask_; bool stats_derived_; @@ -92,9 +94,9 @@ class GroupExpression { namespace std { -template <> -struct hash { - typedef peloton::optimizer::GroupExpression argument_type; +template +struct hash> { + typedef peloton::optimizer::GroupExpression argument_type; typedef std::size_t result_type; result_type operator()(argument_type const &s) const { return s.Hash(); } }; diff --git a/src/include/optimizer/input_column_deriver.h b/src/include/optimizer/input_column_deriver.h index ef66823bba0..dd368f8636f 100644 --- a/src/include/optimizer/input_column_deriver.h +++ b/src/include/optimizer/input_column_deriver.h @@ -27,6 +27,8 @@ class AggregatePlan; namespace optimizer { class OperatorExpression; + +template class Memo; } @@ -44,8 +46,9 @@ class InputColumnDeriver : public OperatorVisitor { std::pair, std::vector>> DeriveInputColumns( - GroupExpression *gexpr, std::shared_ptr properties, - std::vector required_cols, Memo *memo); + GroupExpression *gexpr, std::shared_ptr properties, + std::vector required_cols, + Memo *memo); void Visit(const DummyScan *) override; @@ -108,8 +111,8 @@ class InputColumnDeriver : public OperatorVisitor { * property */ void Passdown(); - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; /** * @brief The derived output columns and input columns, note that the current diff --git a/src/include/optimizer/memo.h 
b/src/include/optimizer/memo.h index 951caa4c94d..be67f961c9a 100644 --- a/src/include/optimizer/memo.h +++ b/src/include/optimizer/memo.h @@ -22,13 +22,15 @@ namespace peloton { namespace optimizer { +template struct GExprPtrHash { - std::size_t operator()(GroupExpression* const& s) const { return s->Hash(); } + std::size_t operator()(GroupExpression* const& s) const { return s->Hash(); } }; +template struct GExprPtrEq { - bool operator()(GroupExpression* const& t1, - GroupExpression* const& t2) const { + bool operator()(GroupExpression* const& t1, + GroupExpression* const& t2) const { return *t1 == *t2; } }; @@ -36,6 +38,7 @@ struct GExprPtrEq { //===--------------------------------------------------------------------===// // Memo //===--------------------------------------------------------------------===// +template class Memo { public: Memo(); @@ -48,15 +51,17 @@ class Memo { * target_group: an optional target group to insert expression into * return: existing expression if found. Otherwise, return the new expr */ - GroupExpression* InsertExpression(std::shared_ptr gexpr, - bool enforced); + GroupExpression* InsertExpression( + std::shared_ptr> gexpr, + bool enforced); - GroupExpression* InsertExpression(std::shared_ptr gexpr, - GroupID target_group, bool enforced); + GroupExpression* InsertExpression( + std::shared_ptr> gexpr, + GroupID target_group, bool enforced); - std::vector>& Groups(); + std::vector>>& Groups(); - Group* GetGroupByID(GroupID id); + Group* GetGroupByID(GroupID id); const std::string GetInfo(int num_indent) const; const std::string GetInfo() const; @@ -68,10 +73,10 @@ class Memo { //===--------------------------------------------------------------------===// // For rewrite phase: remove and add expression directly for the set //===--------------------------------------------------------------------===// - void RemoveParExpressionForRewirte(GroupExpression* gexpr) { + void RemoveParExpressionForRewirte(GroupExpression* gexpr) { 
group_expressions_.erase(gexpr); } - void AddParExpressionForRewrite(GroupExpression* gexpr) { + void AddParExpressionForRewrite(GroupExpression* gexpr) { group_expressions_.insert(gexpr); } // When a rewrite rule is applied, we need to replace the original gexpr with @@ -84,12 +89,13 @@ class Memo { } private: - GroupID AddNewGroup(std::shared_ptr gexpr); + GroupID AddNewGroup(std::shared_ptr> gexpr); // The group owns the group expressions, not the memo - std::unordered_set - group_expressions_; - std::vector> groups_; + std::unordered_set*, + GExprPtrHash, + GExprPtrEq> group_expressions_; + std::vector>> groups_; size_t rule_set_size_; }; diff --git a/src/include/optimizer/optimize_context.h b/src/include/optimizer/optimize_context.h index b5568208d9e..15747a44b5a 100644 --- a/src/include/optimizer/optimize_context.h +++ b/src/include/optimizer/optimize_context.h @@ -22,18 +22,20 @@ namespace peloton { namespace optimizer { +template class OptimizerMetadata; +template class OptimizeContext { public: - OptimizeContext(OptimizerMetadata *metadata, + OptimizeContext(OptimizerMetadata *metadata, std::shared_ptr required_prop, double cost_upper_bound = std::numeric_limits::max()) : metadata(metadata), required_prop(required_prop), cost_upper_bound(cost_upper_bound) {} - OptimizerMetadata *metadata; + OptimizerMetadata *metadata; std::shared_ptr required_prop; double cost_upper_bound; }; diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h index ebf82d625b4..668049b5333 100644 --- a/src/include/optimizer/optimizer.h +++ b/src/include/optimizer/optimizer.h @@ -60,7 +60,10 @@ enum CostModels {DEFAULT, POSTGRES, TRIVIAL}; // Optimizer //===--------------------------------------------------------------------===// class Optimizer : public AbstractOptimizer { + template friend class BindingIterator; + + template friend class GroupBindingIterator; friend class ::peloton::test:: @@ -85,16 +88,18 @@ class Optimizer : public AbstractOptimizer { 
void Reset() override; - OptimizerMetadata &GetMetadata() { return metadata_; } + OptimizerMetadata &GetMetadata() { return metadata_; } /* For test purposes only */ - std::shared_ptr TestInsertQueryTree( - parser::SQLStatement *tree, concurrency::TransactionContext *txn) { + std::shared_ptr> TestInsertQueryTree( + parser::SQLStatement *tree, + concurrency::TransactionContext *txn) { + return InsertQueryTree(tree, txn); } /* For test purposes only */ - void TestExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, - std::shared_ptr root_context) { + void TestExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, + std::shared_ptr> root_context) { return ExecuteTaskStack(task_stack, root_group_id, root_context); } @@ -119,7 +124,7 @@ class Optimizer : public AbstractOptimizer { * tree: a peloton query tree representing a select query * return: the root group expression for the inserted query */ - std::shared_ptr InsertQueryTree( + std::shared_ptr> InsertQueryTree( parser::SQLStatement *tree, concurrency::TransactionContext *txn); /* GetQueryTreeRequiredProperties - get the required physical properties for @@ -161,12 +166,12 @@ class Optimizer : public AbstractOptimizer { * root_context: the OptimizerContext to use that maintains required *properties */ - void ExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, - std::shared_ptr root_context); + void ExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, + std::shared_ptr> root_context); ////////////////////////////////////////////////////////////////////////////// /// Metadata - OptimizerMetadata metadata_; + OptimizerMetadata metadata_; std::unique_ptr cost_model_; }; diff --git a/src/include/optimizer/optimizer_metadata.h b/src/include/optimizer/optimizer_metadata.h index 3f33e3ee8b1..84a6977ee09 100644 --- a/src/include/optimizer/optimizer_metadata.h +++ b/src/include/optimizer/optimizer_metadata.h @@ -26,9 +26,13 @@ class CatalogCache; } namespace optimizer 
{ +template class OptimizerTaskPool; + +template class RuleSet; +template class OptimizerMetadata { public: @@ -37,45 +41,45 @@ class OptimizerMetadata { settings::SettingId::task_execution_timeout)), timer(Timer()) {} - Memo memo; - RuleSet rule_set; - OptimizerTaskPool *task_pool; + Memo memo; + RuleSet rule_set; + OptimizerTaskPool *task_pool; std::unique_ptr cost_model; catalog::CatalogCache *catalog_cache; unsigned int timeout_limit; Timer timer; concurrency::TransactionContext* txn; - void SetTaskPool(OptimizerTaskPool *task_pool) { + void SetTaskPool(OptimizerTaskPool *task_pool) { this->task_pool = task_pool; } - std::shared_ptr MakeGroupExpression( - std::shared_ptr expr) { + std::shared_ptr> MakeGroupExpression( + std::shared_ptr expr) { std::vector child_groups; for (auto &child : expr->Children()) { auto gexpr = MakeGroupExpression(child); memo.InsertExpression(gexpr, false); child_groups.push_back(gexpr->GetGroupID()); } - return std::make_shared(expr->Op(), - std::move(child_groups)); + return std::make_shared>(expr->Op(), + std::move(child_groups)); } - bool RecordTransformedExpression(std::shared_ptr expr, - std::shared_ptr &gexpr) { + bool RecordTransformedExpression(std::shared_ptr expr, + std::shared_ptr> &gexpr) { return RecordTransformedExpression(expr, gexpr, UNDEFINED_GROUP); } - bool RecordTransformedExpression(std::shared_ptr expr, - std::shared_ptr &gexpr, + bool RecordTransformedExpression(std::shared_ptr expr, + std::shared_ptr> &gexpr, GroupID target_group) { gexpr = MakeGroupExpression(expr); return (memo.InsertExpression(gexpr, target_group, false) == gexpr.get()); } // TODO(boweic): check if we really need to use shared_ptr - void ReplaceRewritedExpression(std::shared_ptr expr, + void ReplaceRewritedExpression(std::shared_ptr expr, GroupID target_group) { memo.EraseExpression(target_group); memo.InsertExpression(MakeGroupExpression(expr), target_group, false); diff --git a/src/include/optimizer/optimizer_task.h 
b/src/include/optimizer/optimizer_task.h index fb2edeaa5db..173c64075c6 100644 --- a/src/include/optimizer/optimizer_task.h +++ b/src/include/optimizer/optimizer_task.h @@ -24,14 +24,33 @@ class AbstractExpression; } namespace optimizer { +template class OptimizeContext; + +template class Memo; + +template class Rule; + +template struct RuleWithPromise; + +template class RuleSet; + +template class Group; + +template class GroupExpression; + +template class OptimizerMetadata; + +enum class OpType; +class Operator; +class OperatorExpression; class PropertySet; enum class RewriteRuleSetName : uint32_t; using GroupID = int32_t; @@ -53,9 +72,10 @@ enum class OptimizerTaskType { /** * @brief The base class for tasks in the optimizer */ +template class OptimizerTask { public: - OptimizerTask(std::shared_ptr context, + OptimizerTask(std::shared_ptr> context, OptimizerTaskType type) : type_(type), context_(context) {} @@ -71,24 +91,24 @@ class OptimizerTask { * @param valid_rules The valid rules to apply in the current rule set will be * append to valid_rules, with their promises */ - static void ConstructValidRules(GroupExpression *group_expr, - OptimizeContext *context, - std::vector> &rules, - std::vector &valid_rules); + static void ConstructValidRules(GroupExpression *group_expr, + OptimizeContext *context, + std::vector>> &rules, + std::vector> &valid_rules); virtual void execute() = 0; - void PushTask(OptimizerTask *task); + void PushTask(OptimizerTask *task); - inline Memo &GetMemo() const; + inline Memo &GetMemo() const; - inline RuleSet &GetRuleSet() const; + inline RuleSet &GetRuleSet() const; virtual ~OptimizerTask(){}; protected: OptimizerTaskType type_; - std::shared_ptr context_; + std::shared_ptr> context_; }; /** @@ -96,15 +116,16 @@ class OptimizerTask { * equivalent operator trees if not already explored 2. 
Cost all physical * operator trees given the current context */ -class OptimizeGroup : public OptimizerTask { +class OptimizeGroup : public OptimizerTask { public: - OptimizeGroup(Group *group, std::shared_ptr context) + OptimizeGroup(Group *group, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_GROUP), group_(group) {} virtual void execute() override; private: - Group *group_; + Group *group_; }; /** @@ -114,31 +135,32 @@ class OptimizeGroup : public OptimizerTask { * promises so that a physical transformation rule is applied before a logical * transformation rule */ -class OptimizeExpression : public OptimizerTask { +class OptimizeExpression : public OptimizerTask { public: - OptimizeExpression(GroupExpression *group_expr, - std::shared_ptr context) + OptimizeExpression(GroupExpression *group_expr, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_EXPR), group_expr_(group_expr) {} virtual void execute() override; private: - GroupExpression *group_expr_; + GroupExpression *group_expr_; }; /** * @brief Generate all logical transformation rules by applying logical * transformation rules to logical operators in the group until saturated */ -class ExploreGroup : public OptimizerTask { +class ExploreGroup : public OptimizerTask { public: - ExploreGroup(Group *group, std::shared_ptr context) + ExploreGroup(Group *group, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::EXPLORE_GROUP), group_(group) {} virtual void execute() override; private: - Group *group_; + Group *group_; }; /** @@ -146,16 +168,16 @@ class ExploreGroup : public OptimizerTask { * pattern * in the same group is found, also apply logical transformation rule for it. 
*/ -class ExploreExpression : public OptimizerTask { +class ExploreExpression : public OptimizerTask { public: - ExploreExpression(GroupExpression *group_expr, - std::shared_ptr context) + ExploreExpression(GroupExpression *group_expr, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::EXPLORE_EXPR), group_expr_(group_expr) {} virtual void execute() override; private: - GroupExpression *group_expr_; + GroupExpression *group_expr_; }; /** @@ -164,10 +186,11 @@ class ExploreExpression : public OptimizerTask { * to the new group expression based on the explore flag. If the rule is a * physical implementation rule, we directly cost the physical expression */ -class ApplyRule : public OptimizerTask { +class ApplyRule : public OptimizerTask { public: - ApplyRule(GroupExpression *group_expr, Rule *rule, - std::shared_ptr context, bool explore = false) + ApplyRule(GroupExpression *group_expr, + Rule *rule, + std::shared_ptr> context, bool explore = false) : OptimizerTask(context, OptimizerTaskType::APPLY_RULE), group_expr_(group_expr), rule_(rule), @@ -175,8 +198,8 @@ class ApplyRule : public OptimizerTask { virtual void execute() override; private: - GroupExpression *group_expr_; - Rule *rule_; + GroupExpression *group_expr_; + Rule *rule_; bool explore_only; }; @@ -187,10 +210,10 @@ class ApplyRule : public OptimizerTask { * current expression's cost is larger than the upper bound of the current * group */ -class OptimizeInputs : public OptimizerTask { +class OptimizeInputs : public OptimizerTask { public: - OptimizeInputs(GroupExpression *group_expr, - std::shared_ptr context) + OptimizeInputs(GroupExpression *group_expr, + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_INPUTS), group_expr_(group_expr) {} @@ -208,7 +231,7 @@ class OptimizeInputs : public OptimizerTask { std::vector, std::vector>>> output_input_properties_; - GroupExpression *group_expr_; + GroupExpression *group_expr_; double cur_total_cost_; int 
cur_child_idx_ = -1; int prev_child_idx_ = -1; @@ -220,11 +243,11 @@ class OptimizeInputs : public OptimizerTask { * child group have the stats, if not, recursively derive the stats. This would * lazily collect the stats for the column needed */ -class DeriveStats : public OptimizerTask { +class DeriveStats : public OptimizerTask { public: - DeriveStats(GroupExpression *gexpr, + DeriveStats(GroupExpression *gexpr, ExprSet required_cols, - std::shared_ptr context) + std::shared_ptr> context) : OptimizerTask(context, OptimizerTaskType::DERIVE_STATS), gexpr_(gexpr), required_cols_(required_cols) {} @@ -237,7 +260,7 @@ class DeriveStats : public OptimizerTask { virtual void execute() override; private: - GroupExpression *gexpr_; + GroupExpression *gexpr_; ExprSet required_cols_; }; @@ -247,11 +270,13 @@ class DeriveStats : public OptimizerTask { * level rewrite. An example is predicate push-down. We only push the predicates * from the upper level to the lower level. */ -class TopDownRewrite : public OptimizerTask { +template +class TopDownRewrite : public OptimizerTask { public: - TopDownRewrite(GroupID group_id, std::shared_ptr context, + TopDownRewrite(GroupID group_id, + std::shared_ptr> context, RewriteRuleSetName rule_set_name) - : OptimizerTask(context, OptimizerTaskType::TOP_DOWN_REWRITE), + : OptimizerTask(context, OptimizerTaskType::TOP_DOWN_REWRITE), group_id_(group_id), rule_set_name_(rule_set_name) {} virtual void execute() override; @@ -266,11 +291,13 @@ class TopDownRewrite : public OptimizerTask { * that the upper level rewrite in the operator tree will not enable lower * level rewrite. 
*/ -class BottomUpRewrite : public OptimizerTask { +template +class BottomUpRewrite : public OptimizerTask { public: - BottomUpRewrite(GroupID group_id, std::shared_ptr context, + BottomUpRewrite(GroupID group_id, + std::shared_ptr> context, RewriteRuleSetName rule_set_name, bool has_optimized_child) - : OptimizerTask(context, OptimizerTaskType::BOTTOM_UP_REWRITE), + : OptimizerTask(context, OptimizerTaskType::BOTTOM_UP_REWRITE), group_id_(group_id), rule_set_name_(rule_set_name), has_optimized_child_(has_optimized_child) {} diff --git a/src/include/optimizer/optimizer_task_pool.h b/src/include/optimizer/optimizer_task_pool.h index a14789df64a..2ce755e8de0 100644 --- a/src/include/optimizer/optimizer_task_pool.h +++ b/src/include/optimizer/optimizer_task_pool.h @@ -24,32 +24,35 @@ namespace optimizer { * is identical to a stack but we may need to implement a different data * structure for multi-threaded optimization */ + +template class OptimizerTaskPool { public: - virtual std::unique_ptr Pop() = 0; - virtual void Push(OptimizerTask *task) = 0; + virtual std::unique_ptr> Pop() = 0; + virtual void Push(OptimizerTask *task) = 0; virtual bool Empty() = 0; }; /** * @brief Stack implementation of the task pool */ -class OptimizerTaskStack : public OptimizerTaskPool { +template +class OptimizerTaskStack : public OptimizerTaskPool { public: - virtual std::unique_ptr Pop() { + virtual std::unique_ptr> Pop() { auto task = std::move(task_stack_.top()); task_stack_.pop(); return task; } - virtual void Push(OptimizerTask *task) { - task_stack_.push(std::unique_ptr(task)); + virtual void Push(OptimizerTask *task) { + task_stack_.push(std::unique_ptr>(task)); } virtual bool Empty() { return task_stack_.empty(); } private: - std::stack> task_stack_; + std::stack>> task_stack_; }; } // namespace optimizer diff --git a/src/include/optimizer/pattern.h b/src/include/optimizer/pattern.h index 67c52592889..176fb382b9a 100644 --- a/src/include/optimizer/pattern.h +++ 
b/src/include/optimizer/pattern.h @@ -20,9 +20,13 @@ namespace peloton { namespace optimizer { +/** + * template parameter should *really* only be OpType or ExpressionType + */ +template class Pattern { public: - Pattern(OpType op); + Pattern(OperatorType op); void AddChild(std::shared_ptr child); @@ -30,10 +34,10 @@ class Pattern { inline size_t GetChildPatternsSize() const { return children.size(); } - OpType Type() const; + OperatorType Type() const; private: - OpType _type; + OperatorType _type; std::vector> children; }; diff --git a/src/include/optimizer/property_enforcer.h b/src/include/optimizer/property_enforcer.h index e82b802d84c..c826edbe54d 100644 --- a/src/include/optimizer/property_enforcer.h +++ b/src/include/optimizer/property_enforcer.h @@ -30,8 +30,8 @@ class PropertyEnforcer : public PropertyVisitor { public: - std::shared_ptr EnforceProperty( - GroupExpression* gexpr, Property* property); + std::shared_ptr> EnforceProperty( + GroupExpression* gexpr, Property* property); virtual void Visit(const PropertyColumns *) override; virtual void Visit(const PropertySort *) override; @@ -39,8 +39,8 @@ class PropertyEnforcer : public PropertyVisitor { virtual void Visit(const PropertyLimit *) override; private: - GroupExpression* input_gexpr_; - std::shared_ptr output_gexpr_; + GroupExpression* input_gexpr_; + std::shared_ptr> output_gexpr_; }; } // namespace optimizer diff --git a/src/include/optimizer/rule.h b/src/include/optimizer/rule.h index 4ea78a630c6..b6f85a4c085 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -21,19 +21,18 @@ namespace peloton { namespace optimizer { +template class GroupExpression; #define PHYS_PROMISE 3 #define LOG_PROMISE 1 -/** - * @brief The base class of all rules - */ +template class Rule { public: virtual ~Rule(){}; - std::shared_ptr GetMatchPattern() const { return match_pattern; } + std::shared_ptr> GetMatchPattern() const { return match_pattern; } bool IsPhysical() const { return type_ > 
RuleType::LogicalPhysicalDelimiter && @@ -58,8 +57,8 @@ class Rule { * @return The promise, the higher the promise, the rule should be applied * sooner */ - virtual int Promise(GroupExpression *group_expr, - OptimizeContext *context) const; + virtual int Promise(GroupExpression *group_expr, + OptimizeContext *context) const; /** * @brief Check if the rule is applicable for the operator expression. The @@ -74,8 +73,8 @@ class Rule { * * @return If the rule is applicable, return true, otherwise return false */ - virtual bool Check(std::shared_ptr expr, - OptimizeContext *context) const = 0; + virtual bool Check(std::shared_ptr expr, + OptimizeContext *context) const = 0; /** * @brief Convert a "before" operator tree to an "after" operator tree @@ -85,30 +84,31 @@ class Rule { * @param context The current optimization context */ virtual void Transform( - std::shared_ptr input, - std::vector> &transformed, - OptimizeContext *context) const = 0; + std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const = 0; inline RuleType GetType() { return type_; } inline uint32_t GetRuleIdx() { return static_cast(type_); } protected: - std::shared_ptr match_pattern; + std::shared_ptr> match_pattern; RuleType type_; }; /** * @brief A struct to store a rule together with its promise */ +template struct RuleWithPromise { - RuleWithPromise(Rule *rule, int promise) : rule(rule), promise(promise) {} + RuleWithPromise(Rule *rule, int promise) : rule(rule), promise(promise) {} - Rule *rule; + Rule *rule; int promise; - bool operator<(const RuleWithPromise &r) const { return promise < r.promise; } - bool operator>(const RuleWithPromise &r) const { return promise > r.promise; } + bool operator<(const RuleWithPromise &r) const { return promise < r.promise; } + bool operator>(const RuleWithPromise &r) const { return promise > r.promise; } }; enum class RewriteRuleSetName : uint32_t { @@ -120,41 +120,46 @@ enum class RewriteRuleSetName : uint32_t { * @brief All the 
rule sets, including logical transformation rules, physical * implementation rules and rewrite rules */ +template class RuleSet { public: // RuleSet will take the ownership of the rule object RuleSet(); - inline void AddTransformationRule(Rule* rule) { transformation_rules_.emplace_back(rule); } + inline void AddTransformationRule(Rule* rule) { transformation_rules_.emplace_back(rule); } - inline void AddImplementationRule(Rule* rule) { transformation_rules_.emplace_back(rule); } + inline void AddImplementationRule(Rule* rule) { transformation_rules_.emplace_back(rule); } - inline void AddRewriteRule(RewriteRuleSetName set, Rule* rule) { + inline void AddRewriteRule(RewriteRuleSetName set, Rule* rule) { rewrite_rules_map_[static_cast(set)].emplace_back(rule); } - std::vector> &GetTransformationRules() { + std::vector>> &GetTransformationRules() { return transformation_rules_; } - std::vector> &GetImplementationRules() { + std::vector>> &GetImplementationRules() { return implementation_rules_; } - std::vector> &GetRewriteRulesByName( + std::vector>> &GetRewriteRulesByName( RewriteRuleSetName set) { return rewrite_rules_map_[static_cast(set)]; } - std::unordered_map>> &GetRewriteRulesMap() { return rewrite_rules_map_; } + std::unordered_map>>> &GetRewriteRulesMap() { + return rewrite_rules_map_; + } - std::vector> &GetPredicatePushDownRules() { return predicate_push_down_rules_; } + std::vector>> &GetPredicatePushDownRules() { + return predicate_push_down_rules_; + } private: - std::vector> transformation_rules_; - std::vector> implementation_rules_; - std::unordered_map>> rewrite_rules_map_; - std::vector> predicate_push_down_rules_; + std::vector>> transformation_rules_; + std::vector>> implementation_rules_; + std::unordered_map>>> rewrite_rules_map_; + std::vector>> predicate_push_down_rules_; }; } // namespace optimizer diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index 57902e744a9..810c3b8e8bb 100644 --- 
a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -26,32 +26,32 @@ namespace optimizer { /** * @brief (A join B) -> (B join A) */ -class InnerJoinCommutativity : public Rule { +class InnerJoinCommutativity : public Rule { public: InnerJoinCommutativity(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (A join B) join C -> A join (B join C) */ -class InnerJoinAssociativity : public Rule { +class InnerJoinAssociativity : public Rule { public: InnerJoinAssociativity(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; //===--------------------------------------------------------------------===// @@ -61,239 +61,239 @@ class InnerJoinAssociativity : public Rule { /** * @brief (Logical Scan -> Sequential Scan) */ -class GetToSeqScan : public Rule { +class GetToSeqScan : public Rule { public: GetToSeqScan(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; -class LogicalExternalFileGetToPhysical : public Rule { +class LogicalExternalFileGetToPhysical : public Rule { public: LogicalExternalFileGetToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext 
*context) const override; }; /** * @brief Generate dummy scan for queries like "SELECT 1", there's no actual * table to generate */ -class GetToDummyScan : public Rule { +class GetToDummyScan : public Rule { public: GetToDummyScan(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Scan -> Index Scan) */ -class GetToIndexScan : public Rule { +class GetToIndexScan : public Rule { public: GetToIndexScan(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief Transforming query derived scan for nested query */ -class LogicalQueryDerivedGetToPhysical : public Rule { +class LogicalQueryDerivedGetToPhysical : public Rule { public: LogicalQueryDerivedGetToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Delete -> Physical Delete) */ -class LogicalDeleteToPhysical : public Rule { +class LogicalDeleteToPhysical : public Rule { public: LogicalDeleteToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Update -> Physical Update) */ -class LogicalUpdateToPhysical : public Rule { +class 
LogicalUpdateToPhysical : public Rule { public: LogicalUpdateToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Insert -> Physical Insert) */ -class LogicalInsertToPhysical : public Rule { +class LogicalInsertToPhysical : public Rule { public: LogicalInsertToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Insert Select -> Physical Insert Select) */ -class LogicalInsertSelectToPhysical : public Rule { +class LogicalInsertSelectToPhysical : public Rule { public: LogicalInsertSelectToPhysical(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Group by -> Hash Group by) */ -class LogicalGroupByToHashGroupBy : public Rule { +class LogicalGroupByToHashGroupBy : public Rule { public: LogicalGroupByToHashGroupBy(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Aggregate -> Physical Aggregate) */ -class LogicalAggregateToPhysical : public Rule { +class LogicalAggregateToPhysical : public Rule { public: LogicalAggregateToPhysical(); bool Check(std::shared_ptr plan, 
- OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Inner Join -> Inner Nested-Loop Join) */ -class InnerJoinToInnerNLJoin : public Rule { +class InnerJoinToInnerNLJoin : public Rule { public: InnerJoinToInnerNLJoin(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Inner Join -> Inner Hash Join) */ -class InnerJoinToInnerHashJoin : public Rule { +class InnerJoinToInnerHashJoin : public Rule { public: InnerJoinToInnerHashJoin(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Distinct -> Physical Distinct) */ -class ImplementDistinct : public Rule { +class ImplementDistinct : public Rule { public: ImplementDistinct(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief (Logical Limit -> Physical Limit) */ -class ImplementLimit : public Rule { +class ImplementLimit : public Rule { public: ImplementLimit(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const 
override; + OptimizeContext *context) const override; }; /** * @brief Logical Export to External File -> Physical Export to External file */ -class LogicalExportToPhysicalExport : public Rule { +class LogicalExportToPhysicalExport : public Rule { public: LogicalExportToPhysicalExport(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; //===--------------------------------------------------------------------===// @@ -306,63 +306,63 @@ class LogicalExportToPhysicalExport : public Rule { * we could push "test.a=5" through the join to evaluate at the table scan * level */ -class PushFilterThroughJoin : public Rule { +class PushFilterThroughJoin : public Rule { public: PushFilterThroughJoin(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief Combine multiple filters into one single filter using conjunction */ -class CombineConsecutiveFilter : public Rule { +class CombineConsecutiveFilter : public Rule { public: CombineConsecutiveFilter(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief perform predicate push-down to push a filter through aggregation, also * will embed filter into aggregation operator if appropriate. 
*/ -class PushFilterThroughAggregation : public Rule { +class PushFilterThroughAggregation : public Rule { public: PushFilterThroughAggregation(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /** * @brief Embed a filter into a scan operator. After predicate push-down, we * eliminate all filters in the operator trees, predicates should be associated * with get or join */ -class EmbedFilterIntoGet : public Rule { +class EmbedFilterIntoGet : public Rule { public: EmbedFilterIntoGet(); bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// @@ -377,69 +377,69 @@ enum class UnnestPromise { Low = 1, High }; // should not use the following rules in the rewrite phase /////////////////////////////////////////////////////////////////////////////// /// MarkJoinGetToInnerJoin -class MarkJoinToInnerJoin : public Rule { +class MarkJoinToInnerJoin : public Rule { public: MarkJoinToInnerJoin(); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// /// SingleJoinToInnerJoin -class 
SingleJoinToInnerJoin : public Rule { +class SingleJoinToInnerJoin : public Rule { public: SingleJoinToInnerJoin(); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// /// PullFilterThroughMarkJoin -class PullFilterThroughMarkJoin : public Rule { +class PullFilterThroughMarkJoin : public Rule { public: PullFilterThroughMarkJoin(); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// /// PullFilterThroughAggregation -class PullFilterThroughAggregation : public Rule { +class PullFilterThroughAggregation : public Rule { public: PullFilterThroughAggregation(); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + OptimizeContext *context) const override; void Transform(std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const override; + OptimizeContext *context) const override; }; } // namespace optimizer } 
// namespace peloton diff --git a/src/include/optimizer/stats/child_stats_deriver.h b/src/include/optimizer/stats/child_stats_deriver.h index d0c72f9bf9b..cfca18e30d9 100644 --- a/src/include/optimizer/stats/child_stats_deriver.h +++ b/src/include/optimizer/stats/child_stats_deriver.h @@ -21,15 +21,19 @@ class AbstractExpression; } namespace optimizer { +template class Memo; +class OperatorExpression; + // Derive child stats that has not yet been calculated for a logical group // expression class ChildStatsDeriver : public OperatorVisitor { public: std::vector DeriveInputStats( - GroupExpression *gexpr, - ExprSet required_cols, Memo *memo); + GroupExpression *gexpr, + ExprSet required_cols, + Memo *memo); void Visit(const LogicalQueryDerivedGet *) override; void Visit(const LogicalInnerJoin *) override; @@ -43,8 +47,8 @@ class ChildStatsDeriver : public OperatorVisitor { void PassDownRequiredCols(); void PassDownColumn(expression::AbstractExpression* col); ExprSet required_cols_; - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; std::vector output_; }; diff --git a/src/include/optimizer/stats/stats_calculator.h b/src/include/optimizer/stats/stats_calculator.h index 9637db2f224..6fed68370f9 100644 --- a/src/include/optimizer/stats/stats_calculator.h +++ b/src/include/optimizer/stats/stats_calculator.h @@ -17,8 +17,10 @@ namespace peloton { namespace optimizer { +template class Memo; class TableStats; +class OperatorExpression; /** * @brief Derive stats for the root group using a group expression's children's @@ -26,8 +28,10 @@ class TableStats; */ class StatsCalculator : public OperatorVisitor { public: - void CalculateStats(GroupExpression *gexpr, ExprSet required_cols, - Memo *memo, concurrency::TransactionContext* txn); + void CalculateStats(GroupExpression *gexpr, + ExprSet required_cols, + Memo *memo, + concurrency::TransactionContext* txn); void Visit(const LogicalGet *) override; void Visit(const LogicalQueryDerivedGet *) 
override; @@ -72,9 +76,9 @@ class StatsCalculator : public OperatorVisitor { const std::shared_ptr predicate_table_stats, const expression::AbstractExpression *expr); - GroupExpression *gexpr_; + GroupExpression *gexpr_; ExprSet required_cols_; - Memo *memo_; + Memo *memo_; concurrency::TransactionContext* txn_; }; diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index 9651ce8102c..e0091d2d0b1 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -22,19 +22,25 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Group Binding Iterator //===--------------------------------------------------------------------===// -GroupBindingIterator::GroupBindingIterator(Memo &memo, GroupID id, - std::shared_ptr pattern) - : BindingIterator(memo), +template +GroupBindingIterator::GroupBindingIterator( + Memo &memo, + GroupID id, + std::shared_ptr> pattern) + : BindingIterator(memo), group_id_(id), pattern_(pattern), - target_group_(memo_.GetGroupByID(id)), + target_group_(this->memo_.GetGroupByID(id)), num_group_items_(target_group_->GetLogicalExpressions().size()), current_item_index_(0) { LOG_TRACE("Attempting to bind on group %d", id); } -bool GroupBindingIterator::HasNext() { +template +bool GroupBindingIterator::HasNext() { LOG_TRACE("HasNext"); + + //(TODO): GroupBindingIterator::HasNext() probably needs specialization if (pattern_->Type() == OpType::Leaf) { return current_item_index_ == 0; } @@ -50,8 +56,8 @@ bool GroupBindingIterator::HasNext() { if (current_iterator_ == nullptr) { // Keep checking item iterators until we find a match while (current_item_index_ < num_group_items_) { - current_iterator_.reset(new GroupExprBindingIterator( - memo_, + current_iterator_.reset(new GroupExprBindingIterator( + this->memo_, target_group_->GetLogicalExpressions()[current_item_index_].get(), pattern_)); @@ -67,10 +73,12 @@ bool GroupBindingIterator::HasNext() { return current_iterator_ != 
nullptr; } -std::shared_ptr GroupBindingIterator::Next() { +template +std::shared_ptr GroupBindingIterator::Next() { + //(TODO): GroupBindingIterator::Next() probably needs specialization if (pattern_->Type() == OpType::Leaf) { current_item_index_ = num_group_items_; - return std::make_shared(LeafOperator::make(group_id_)); + return std::make_shared(LeafOperator::make(group_id_)); } return current_iterator_->Next(); } @@ -78,20 +86,23 @@ std::shared_ptr GroupBindingIterator::Next() { //===--------------------------------------------------------------------===// // Item Binding Iterator //===--------------------------------------------------------------------===// -GroupExprBindingIterator::GroupExprBindingIterator( - Memo &memo, GroupExpression *gexpr, std::shared_ptr pattern) - : BindingIterator(memo), +template +GroupExprBindingIterator::GroupExprBindingIterator( + Memo &memo, + GroupExpression *gexpr, + std::shared_ptr> pattern) + : BindingIterator(memo), gexpr_(gexpr), pattern_(pattern), first_(true), has_next_(false), - current_binding_(std::make_shared(gexpr->Op())) { + current_binding_(std::make_shared(gexpr->Op())) { if (gexpr->Op().GetType() != pattern->Type()) { return; } const std::vector &child_groups = gexpr->GetChildGroupIDs(); - const std::vector> &child_patterns = + const std::vector>> &child_patterns = pattern->Children(); if (child_groups.size() != child_patterns.size()) { @@ -107,9 +118,9 @@ GroupExprBindingIterator::GroupExprBindingIterator( children_bindings_pos_.resize(child_groups.size(), 0); for (size_t i = 0; i < child_groups.size(); ++i) { // Try to find a match in the given group - std::vector> &child_bindings = + std::vector> &child_bindings = children_bindings_[i]; - GroupBindingIterator iterator(memo_, child_groups[i], child_patterns[i]); + GroupBindingIterator iterator(this->memo_, child_groups[i], child_patterns[i]); // Get all bindings while (iterator.HasNext()) { @@ -126,7 +137,8 @@ 
GroupExprBindingIterator::GroupExprBindingIterator( has_next_ = true; } -bool GroupExprBindingIterator::HasNext() { +template +bool GroupExprBindingIterator::HasNext() { LOG_TRACE("HasNext"); if (has_next_ && first_) { first_ = false; @@ -137,8 +149,7 @@ bool GroupExprBindingIterator::HasNext() { // The first child to be modified int first_modified_idx = children_bindings_pos_.size() - 1; for (; first_modified_idx >= 0; --first_modified_idx) { - const std::vector> &child_binding = - children_bindings_[first_modified_idx]; + const std::vector> &child_binding = children_bindings_[first_modified_idx]; // Try to increment idx from the back size_t new_pos = ++children_bindings_pos_[first_modified_idx]; @@ -154,17 +165,14 @@ bool GroupExprBindingIterator::HasNext() { has_next_ = false; } else { // Pop all updated childrens - for (size_t idx = first_modified_idx; idx < children_bindings_pos_.size(); - idx++) { + for (size_t idx = first_modified_idx; idx < children_bindings_pos_.size(); idx++) { current_binding_->PopChild(); } // Add new children to end for (size_t offset = first_modified_idx; offset < children_bindings_pos_.size(); ++offset) { - const std::vector> &child_binding = - children_bindings_[offset]; - std::shared_ptr binding = - child_binding[children_bindings_pos_[offset]]; + const std::vector> &child_binding = children_bindings_[offset]; + std::shared_ptr binding = child_binding[children_bindings_pos_[offset]]; current_binding_->PushChild(binding); } } @@ -172,9 +180,14 @@ bool GroupExprBindingIterator::HasNext() { return has_next_; } -std::shared_ptr GroupExprBindingIterator::Next() { +template +std::shared_ptr GroupExprBindingIterator::Next() { return current_binding_; } +// Explicitly instantiate +template class GroupBindingIterator; +template class GroupExprBindingIterator; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index b432067fae1..c025eed7dff 100644 
--- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -31,9 +31,9 @@ namespace peloton { namespace optimizer { vector, vector>>> -ChildPropertyDeriver::GetProperties(GroupExpression *gexpr, +ChildPropertyDeriver::GetProperties(GroupExpression *gexpr, shared_ptr requirements, - Memo *memo) { + Memo *memo) { requirements_ = requirements; output_.clear(); memo_ = memo; @@ -218,7 +218,7 @@ void ChildPropertyDeriver::DeriveForJoin() { if (prop->Type() == PropertyType::SORT) { auto sort_prop = prop->As(); size_t sort_col_size = sort_prop->GetSortColumnSize(); - Group *probe_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(1)); + Group *probe_group = memo_->GetGroupByID(gexpr_->GetChildGroupId(1)); bool can_pass_down = true; for (size_t idx = 0; idx < sort_col_size; ++idx) { ExprSet tuples; diff --git a/src/optimizer/group.cpp b/src/optimizer/group.cpp index 673a7a1b8bd..5f248a415db 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "optimizer/group.h" +#include "optimizer/operator_expression.h" #include "common/logger.h" @@ -20,13 +21,18 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Group //===--------------------------------------------------------------------===// -Group::Group(GroupID id, std::unordered_set table_aliases) +template +Group::Group(GroupID id, std::unordered_set table_aliases) : id_(id), table_aliases_(std::move(table_aliases)) { has_explored_ = false; } -void Group::AddExpression(std::shared_ptr expr, - bool enforced) { +template +void Group::AddExpression( + std::shared_ptr> expr, + bool enforced) { + + //(TODO): rethink how separation works with AbstractExpressions // Do duplicate detection expr->SetGroupID(id_); if (enforced) @@ -37,8 +43,12 @@ void Group::AddExpression(std::shared_ptr expr, 
logical_expressions_.push_back(expr); } -bool Group::SetExpressionCost(GroupExpression *expr, double cost, - std::shared_ptr &properties) { +template +bool Group::SetExpressionCost( + GroupExpression *expr, + double cost, + std::shared_ptr &properties) { + LOG_TRACE("Adding expression cost on group %d with op %s, req %s", expr->GetGroupID(), expr->Op().GetName().c_str(), properties->ToString().c_str()); @@ -51,8 +61,11 @@ bool Group::SetExpressionCost(GroupExpression *expr, double cost, } return false; } -GroupExpression *Group::GetBestExpression( + +template +GroupExpression *Group::GetBestExpression( std::shared_ptr &properties) { + auto it = lowest_cost_expressions_.find(properties); if (it != lowest_cost_expressions_.end()) { return std::get<1>(it->second); @@ -62,20 +75,22 @@ GroupExpression *Group::GetBestExpression( return nullptr; } -bool Group::HasExpressions( - const std::shared_ptr &properties) const { +template +bool Group::HasExpressions(const std::shared_ptr &properties) const { const auto &it = lowest_cost_expressions_.find(properties); return (it != lowest_cost_expressions_.end()); } -std::shared_ptr Group::GetStats(std::string column_name) { +template +std::shared_ptr Group::GetStats(std::string column_name) { if (!stats_.count(column_name)) { return nullptr; } return stats_[column_name]; } -const std::string Group::GetInfo(int num_indent) const { +template +const std::string Group::GetInfo(int num_indent) const { std::ostringstream os; os << StringUtil::Indent(num_indent) << "GroupID: " << GetID() << std::endl; @@ -134,22 +149,28 @@ const std::string Group::GetInfo(int num_indent) const { return os.str(); } -const std::string Group::GetInfo() const { +template +const std::string Group::GetInfo() const { std::ostringstream os; os << GetInfo(0); return os.str(); } -void Group::AddStats(std::string column_name, +template +void Group::AddStats(std::string column_name, std::shared_ptr stats) { PELOTON_ASSERT((size_t)GetNumRows() == stats->num_rows); 
stats_[column_name] = stats; } -bool Group::HasColumnStats(std::string column_name) { +template +bool Group::HasColumnStats(std::string column_name) { return stats_.count(column_name); } +// Explicitly instantiate +template class Group; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/group_expression.cpp b/src/optimizer/group_expression.cpp index 498c949b583..08c88897e6b 100644 --- a/src/optimizer/group_expression.cpp +++ b/src/optimizer/group_expression.cpp @@ -21,41 +21,51 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Group Expression //===--------------------------------------------------------------------===// -GroupExpression::GroupExpression(Operator op, std::vector child_groups) +template +GroupExpression::GroupExpression(Node op, std::vector child_groups) : group_id(UNDEFINED_GROUP), op(op), child_groups(child_groups), stats_derived_(false) {} -GroupID GroupExpression::GetGroupID() const { return group_id; } +template +GroupID GroupExpression::GetGroupID() const { return group_id; } -void GroupExpression::SetGroupID(GroupID id) { group_id = id; } +template +void GroupExpression::SetGroupID(GroupID id) { group_id = id; } -void GroupExpression::SetChildGroupID(int child_group_idx, GroupID group_id) { +template +void GroupExpression::SetChildGroupID(int child_group_idx, GroupID group_id) { child_groups[child_group_idx] = group_id; } -const std::vector &GroupExpression::GetChildGroupIDs() const { +template +const std::vector &GroupExpression::GetChildGroupIDs() const { return child_groups; } -GroupID GroupExpression::GetChildGroupId(int child_idx) const { +template +GroupID GroupExpression::GetChildGroupId(int child_idx) const { return child_groups[child_idx]; } -Operator GroupExpression::Op() const { return op; } +template +Node GroupExpression::Op() const { return op; } -double GroupExpression::GetCost( +template +double GroupExpression::GetCost( std::shared_ptr 
&requirements) const { return std::get<0>(lowest_cost_table_.find(requirements)->second); } -std::vector> GroupExpression::GetInputProperties( +template +std::vector> GroupExpression::GetInputProperties( std::shared_ptr requirements) const { return std::get<1>(lowest_cost_table_.find(requirements)->second); } -void GroupExpression::SetLocalHashTable( +template +void GroupExpression::SetLocalHashTable( const std::shared_ptr &output_properties, const std::vector> &input_properties_list, double cost) { @@ -73,7 +83,8 @@ void GroupExpression::SetLocalHashTable( } } -hash_t GroupExpression::Hash() const { +template +hash_t GroupExpression::Hash() const { size_t hash = op.Hash(); for (size_t i = 0; i < child_groups.size(); ++i) { @@ -84,17 +95,23 @@ hash_t GroupExpression::Hash() const { return hash; } -bool GroupExpression::operator==(const GroupExpression &r) { +template +bool GroupExpression::operator==(const GroupExpression &r) { return (op == r.Op()) && (child_groups == r.child_groups); } -void GroupExpression::SetRuleExplored(Rule *rule) { +template +void GroupExpression::SetRuleExplored(Rule *rule) { rule_mask_.set(rule->GetRuleIdx(), true); } -bool GroupExpression::HasRuleExplored(Rule *rule) { +template +bool GroupExpression::HasRuleExplored(Rule *rule) { return rule_mask_.test(rule->GetRuleIdx()); } +// Explicitly instantiate to prevent linker errors +template class GroupExpression; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index fdffb7e79a6..30ee095a379 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -37,8 +37,9 @@ InputColumnDeriver::InputColumnDeriver() {} pair, vector>> InputColumnDeriver::DeriveInputColumns( - GroupExpression *gexpr, shared_ptr properties, - vector required_cols, Memo *memo) { + GroupExpression *gexpr, shared_ptr properties, + vector required_cols, + Memo *memo) { properties_ = properties; 
gexpr_ = gexpr; required_cols_ = move(required_cols); diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index ca68a52c1d0..5f86f988fa3 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -21,15 +21,77 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Memo //===--------------------------------------------------------------------===// -Memo::Memo() {} +template +Memo::Memo() {} + +//===--------------------------------------------------------------------===// +// Memo::AddNewGroup (declare here to prevent specialization error) +//===--------------------------------------------------------------------===// +template +GroupID Memo::AddNewGroup(std::shared_ptr> gexpr) { + //(TODO): handle general case/AbstractExpressions + (void)gexpr; + PELOTON_ASSERT(0); + return 0; +} + +template <> +GroupID Memo::AddNewGroup(std::shared_ptr> gexpr) { + GroupID new_group_id = groups_.size(); + // Find out the table alias that this group represents + std::unordered_set table_aliases; + auto op_type = gexpr->Op().GetType(); + if (op_type == OpType::Get) { + // For base group, the table alias can get directly from logical get + const LogicalGet *logical_get = gexpr->Op().As(); + table_aliases.insert(logical_get->table_alias); + } else if (op_type == OpType::LogicalQueryDerivedGet) { + const LogicalQueryDerivedGet *query_get = + gexpr->Op().As(); + table_aliases.insert(query_get->table_alias); + } else { + // For other groups, need to aggregate the table alias from children + for (auto child_group_id : gexpr->GetChildGroupIDs()) { + Group *child_group = GetGroupByID(child_group_id); + for (auto &table_alias : child_group->GetTableAliases()) { + table_aliases.insert(table_alias); + } + } + } + + groups_.emplace_back( + new Group(new_group_id, std::move(table_aliases))); + return new_group_id; +} + +//===--------------------------------------------------------------------===// +// Memo remaining interface 
functions +//===--------------------------------------------------------------------===// +template +GroupExpression *Memo::InsertExpression( + std::shared_ptr> gexpr, + bool enforced) { -GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, - bool enforced) { return InsertExpression(gexpr, UNDEFINED_GROUP, enforced); } -GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, - GroupID target_group, bool enforced) { +template +GroupExpression *Memo::InsertExpression( + std::shared_ptr> gexpr, + GroupID target_group, + bool enforced) { + + //(TODO): handle general/AbstractExpression case + PELOTON_ASSERT(0); + return nullptr; +} + +// Specialization for Memo::InsertExpression due to OpType +template <> +GroupExpression *Memo::InsertExpression( + std::shared_ptr> gexpr, + GroupID target_group, + bool enforced) { // If leaf, then just return if (gexpr->Op().GetType() == OpType::Leaf) { const LeafOperator *leaf = gexpr->Op().As(); @@ -55,19 +117,22 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, } else { group_id = target_group; } - Group *group = GetGroupByID(group_id); + Group *group = GetGroupByID(group_id); group->AddExpression(gexpr, enforced); return gexpr.get(); } } -std::vector> &Memo::Groups() { +template +std::vector>> &Memo::Groups() { return groups_; } -Group *Memo::GetGroupByID(GroupID id) { return groups_[id].get(); } +template +Group *Memo::GetGroupByID(GroupID id) { return groups_[id].get(); } -const std::string Memo::GetInfo(int num_indent) const { +template +const std::string Memo::GetInfo(int num_indent) const { std::ostringstream os; os << StringUtil::Indent(num_indent) << "Memo::\n"; os << StringUtil::Indent(num_indent + 1) @@ -80,40 +145,15 @@ const std::string Memo::GetInfo(int num_indent) const { return os.str(); } -const std::string Memo::GetInfo() const { +template +const std::string Memo::GetInfo() const { std::ostringstream os; os << GetInfo(0); return os.str(); } - -GroupID Memo::AddNewGroup(std::shared_ptr 
gexpr) { - GroupID new_group_id = groups_.size(); - // Find out the table alias that this group represents - std::unordered_set table_aliases; - auto op_type = gexpr->Op().GetType(); - if (op_type == OpType::Get) { - // For base group, the table alias can get directly from logical get - const LogicalGet *logical_get = gexpr->Op().As(); - table_aliases.insert(logical_get->table_alias); - } else if (op_type == OpType::LogicalQueryDerivedGet) { - const LogicalQueryDerivedGet *query_get = - gexpr->Op().As(); - table_aliases.insert(query_get->table_alias); - } else { - // For other groups, need to aggregate the table alias from children - for (auto child_group_id : gexpr->GetChildGroupIDs()) { - Group *child_group = GetGroupByID(child_group_id); - for (auto &table_alias : child_group->GetTableAliases()) { - table_aliases.insert(table_alias); - } - } - } - - groups_.emplace_back( - new Group(new_group_id, std::move(table_aliases))); - return new_group_id; -} +// Explicitly instantiate template +template class Memo; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index 83bcadde4de..5c6d8ac304c 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -65,15 +65,15 @@ Optimizer::Optimizer(const CostModels cost_model) : metadata_(nullptr) { switch (cost_model) { case CostModels::DEFAULT: { - metadata_ = OptimizerMetadata(std::unique_ptr(new DefaultCostModel)); + metadata_ = OptimizerMetadata(std::unique_ptr(new DefaultCostModel)); break; } case CostModels::POSTGRES: { - metadata_ = OptimizerMetadata(std::unique_ptr(new PostgresCostModel)); + metadata_ = OptimizerMetadata(std::unique_ptr(new PostgresCostModel)); break; } case CostModels::TRIVIAL: { - metadata_ = OptimizerMetadata(std::unique_ptr(new TrivialCostModel)); + metadata_ = OptimizerMetadata(std::unique_ptr(new TrivialCostModel)); break; } default: @@ -83,17 +83,17 @@ Optimizer::Optimizer(const CostModels cost_model) : 
metadata_(nullptr) { void Optimizer::OptimizeLoop(int root_group_id, std::shared_ptr required_props) { - std::shared_ptr root_context = - std::make_shared(&metadata_, required_props); + std::shared_ptr> root_context = + std::make_shared>(&metadata_, required_props); auto task_stack = - std::unique_ptr(new OptimizerTaskStack()); + std::unique_ptr>(new OptimizerTaskStack()); metadata_.SetTaskPool(task_stack.get()); // Perform rewrite first - task_stack->Push(new TopDownRewrite(root_group_id, root_context, + task_stack->Push(new TopDownRewrite(root_group_id, root_context, RewriteRuleSetName::PREDICATE_PUSH_DOWN)); - task_stack->Push(new BottomUpRewrite( + task_stack->Push(new BottomUpRewrite( root_group_id, root_context, RewriteRuleSetName::UNNEST_SUBQUERY, false)); ExecuteTaskStack(*task_stack, root_group_id, root_context); @@ -132,7 +132,7 @@ shared_ptr Optimizer::BuildPelotonPlanTree( metadata_.txn = txn; // Generate initial operator tree from query tree - shared_ptr gexpr = InsertQueryTree(parse_tree, txn); + shared_ptr> gexpr = InsertQueryTree(parse_tree, txn); GroupID root_id = gexpr->GetGroupID(); // Get the physical properties the final plan must output auto query_info = GetQueryInfo(parse_tree); @@ -158,7 +158,7 @@ shared_ptr Optimizer::BuildPelotonPlanTree( } void Optimizer::Reset() { - metadata_ = OptimizerMetadata(std::move(metadata_.cost_model)); + metadata_ = OptimizerMetadata(std::move(metadata_.cost_model)); } unique_ptr Optimizer::HandleDDLStatement( @@ -247,12 +247,12 @@ unique_ptr Optimizer::HandleDDLStatement( return ddl_plan; } -shared_ptr Optimizer::InsertQueryTree( +shared_ptr> Optimizer::InsertQueryTree( parser::SQLStatement *tree, concurrency::TransactionContext *txn) { QueryToOperatorTransformer converter(txn); shared_ptr initial = converter.ConvertToOpExpression(tree); - shared_ptr gexpr; + shared_ptr> gexpr; metadata_.RecordTransformedExpression(initial, gexpr); return gexpr; } @@ -323,7 +323,7 @@ const std::string 
Optimizer::GetOperatorInfo( int num_indent) { std::ostringstream os; - Group *group = metadata_.memo.GetGroupByID(id); + Group *group = metadata_.memo.GetGroupByID(id); auto gexpr = group->GetBestExpression(required_props); os << std::endl << StringUtil::Indent(num_indent) << "operator name: " @@ -347,7 +347,7 @@ const std::string Optimizer::GetOperatorInfo( unique_ptr Optimizer::ChooseBestPlan( GroupID id, std::shared_ptr required_props, std::vector required_cols) { - Group *group = metadata_.memo.GetGroupByID(id); + Group *group = metadata_.memo.GetGroupByID(id); LOG_TRACE("Choosing with property : %s", required_props->ToString().c_str()); auto gexpr = group->GetBestExpression(required_props); @@ -395,8 +395,8 @@ unique_ptr Optimizer::ChooseBestPlan( } void Optimizer::ExecuteTaskStack( - OptimizerTaskStack &task_stack, int root_group_id, - std::shared_ptr root_context) { + OptimizerTaskStack &task_stack, int root_group_id, + std::shared_ptr> root_context) { auto root_group = metadata_.memo.GetGroupByID(root_group_id); auto &timer = metadata_.timer; const auto timeout_limit = metadata_.timeout_limit; diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index e1cfac5643d..8026828ee2d 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -25,10 +25,12 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Base class //===--------------------------------------------------------------------===// -void OptimizerTask::ConstructValidRules( - GroupExpression *group_expr, OptimizeContext *context, - std::vector> &rules, - std::vector &valid_rules) { +template +void OptimizerTask::ConstructValidRules( + GroupExpression *group_expr, + OptimizeContext *context, + std::vector>> &rules, + std::vector> &valid_rules) { for (auto &rule : rules) { // Check if we can apply the rule bool root_pattern_mismatch = @@ -45,13 +47,16 @@ void OptimizerTask::ConstructValidRules( 
} } -void OptimizerTask::PushTask(OptimizerTask *task) { +template +void OptimizerTask::PushTask(OptimizerTask *task) { context_->metadata->task_pool->Push(task); } -Memo &OptimizerTask::GetMemo() const { return context_->metadata->memo; } +template +Memo &OptimizerTask::GetMemo() const { return context_->metadata->memo; } -RuleSet &OptimizerTask::GetRuleSet() const { +template +RuleSet &OptimizerTask::GetRuleSet() const { return context_->metadata->rule_set; } @@ -86,14 +91,14 @@ void OptimizeGroup::execute() { // OptimizeExpression //===--------------------------------------------------------------------===// void OptimizeExpression::execute() { - std::vector valid_rules; + std::vector> valid_rules; // Construct valid transformation rules from rule set - ConstructValidRules(group_expr_, context_.get(), - GetRuleSet().GetTransformationRules(), valid_rules); + this->ConstructValidRules(group_expr_, context_.get(), + GetRuleSet().GetTransformationRules(), valid_rules); // Construct valid implementation rules from rule set - ConstructValidRules(group_expr_, context_.get(), - GetRuleSet().GetImplementationRules(), valid_rules); + this->ConstructValidRules(group_expr_, context_.get(), + GetRuleSet().GetImplementationRules(), valid_rules); std::sort(valid_rules.begin(), valid_rules.end()); LOG_DEBUG("OptimizeExpression::execute() op %d, valid rules : %lu", @@ -138,7 +143,7 @@ void ExploreGroup::execute() { //===--------------------------------------------------------------------===// void ExploreExpression::execute() { LOG_TRACE("ExploreExpression::execute() "); - std::vector valid_rules; + std::vector> valid_rules; // Construct valid transformation rules from rule set ConstructValidRules(group_expr_, context_.get(), @@ -172,8 +177,8 @@ void ApplyRule::execute() { LOG_TRACE("ApplyRule::execute() for rule: %d", rule_->GetRuleIdx()); if (group_expr_->HasRuleExplored(rule_)) return; - GroupExprBindingIterator iterator(GetMemo(), group_expr_, - rule_->GetMatchPattern()); + 
GroupExprBindingIterator iterator(GetMemo(), group_expr_, + rule_->GetMatchPattern()); while (iterator.HasNext()) { auto before = iterator.Next(); if (!rule_->Check(before, context_.get())) { @@ -183,7 +188,7 @@ void ApplyRule::execute() { std::vector> after; rule_->Transform(before, after, context_.get()); for (auto &new_expr : after) { - std::shared_ptr new_gexpr; + std::shared_ptr> new_gexpr; if (context_->metadata->RecordTransformedExpression( new_expr, new_gexpr, group_expr_->GetGroupID())) { // A new group expression is generated @@ -315,7 +320,7 @@ void OptimizeInputs::execute() { prev_child_idx_ = cur_child_idx_; PushTask(new OptimizeInputs(this)); PushTask(new OptimizeGroup( - child_group, std::make_shared( + child_group, std::make_shared>( context_->metadata, i_prop, context_->cost_upper_bound - cur_total_cost_))); return; } else { // If we return from OptimizeGroup, then there is no expr for @@ -336,7 +341,7 @@ void OptimizeInputs::execute() { // Enforce property if the requirement does not meet PropertyEnforcer prop_enforcer; auto extended_output_properties = output_prop->Properties(); - GroupExpression *memo_enforced_expr = nullptr; + GroupExpression *memo_enforced_expr = nullptr; bool meet_requirement = true; // TODO: For now, we enforce the missing properties in the order of how we // find them. 
This may @@ -402,29 +407,30 @@ void OptimizeInputs::execute() { } } -void TopDownRewrite::execute() { - std::vector valid_rules; +template +void TopDownRewrite::execute() { + std::vector> valid_rules; - auto cur_group = GetMemo().GetGroupByID(group_id_); + auto cur_group = this->GetMemo().GetGroupByID(group_id_); auto cur_group_expr = cur_group->GetLogicalExpression(); // Construct valid transformation rules from rule set - ConstructValidRules(cur_group_expr, context_.get(), - GetRuleSet().GetRewriteRulesByName(rule_set_name_), - valid_rules); + this->ConstructValidRules(cur_group_expr, this->context_.get(), + this->GetRuleSet().GetRewriteRulesByName(rule_set_name_), + valid_rules); // Sort so that we apply rewrite rules with higher promise first std::sort(valid_rules.begin(), valid_rules.end(), - std::greater()); + std::greater>()); for (auto &r : valid_rules) { - GroupExprBindingIterator iterator(GetMemo(), cur_group_expr, - r.rule->GetMatchPattern()); + GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, + r.rule->GetMatchPattern()); if (iterator.HasNext()) { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); - std::vector> after; - r.rule->Transform(before, after, context_.get()); + std::vector> after; + r.rule->Transform(before, after, this->context_.get()); // Rewrite rule should provide at most 1 expression PELOTON_ASSERT(after.size() <= 1); @@ -433,8 +439,8 @@ void TopDownRewrite::execute() { // saturated if (!after.empty()) { auto &new_expr = after[0]; - context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); - PushTask(new TopDownRewrite(group_id_, context_, rule_set_name_)); + this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); + this->PushTask(new TopDownRewrite(group_id_, this->context_, rule_set_name_)); return; } } @@ -445,47 +451,48 @@ void TopDownRewrite::execute() { child_group_idx < cur_group_expr->GetChildrenGroupsSize(); child_group_idx++) { // Need to rewrite all sub trees first 
- PushTask( - new TopDownRewrite(cur_group_expr->GetChildGroupId(child_group_idx), - context_, rule_set_name_)); + this->PushTask( + new TopDownRewrite(cur_group_expr->GetChildGroupId(child_group_idx), + this->context_, rule_set_name_)); } } -void BottomUpRewrite::execute() { - std::vector valid_rules; +template +void BottomUpRewrite::execute() { + std::vector> valid_rules; - auto cur_group = GetMemo().GetGroupByID(group_id_); + auto cur_group = this->GetMemo().GetGroupByID(group_id_); auto cur_group_expr = cur_group->GetLogicalExpression(); if (!has_optimized_child_) { - PushTask(new BottomUpRewrite(group_id_, context_, rule_set_name_, true)); + this->PushTask(new BottomUpRewrite(group_id_, this->context_, rule_set_name_, true)); for (size_t child_group_idx = 0; child_group_idx < cur_group_expr->GetChildrenGroupsSize(); child_group_idx++) { // Need to rewrite all sub trees first - PushTask( - new BottomUpRewrite(cur_group_expr->GetChildGroupId(child_group_idx), - context_, rule_set_name_, false)); + this->PushTask( + new BottomUpRewrite(cur_group_expr->GetChildGroupId(child_group_idx), + this->context_, rule_set_name_, false)); } return; } // Construct valid transformation rules from rule set - ConstructValidRules(cur_group_expr, context_.get(), - GetRuleSet().GetRewriteRulesByName(rule_set_name_), - valid_rules); + this->ConstructValidRules(cur_group_expr, this->context_.get(), + this->GetRuleSet().GetRewriteRulesByName(rule_set_name_), + valid_rules); // Sort so that we apply rewrite rules with higher promise first std::sort(valid_rules.begin(), valid_rules.end(), - std::greater()); + std::greater>()); for (auto &r : valid_rules) { - GroupExprBindingIterator iterator(GetMemo(), cur_group_expr, - r.rule->GetMatchPattern()); + GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, + r.rule->GetMatchPattern()); if (iterator.HasNext()) { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); - std::vector> after; - r.rule->Transform(before, 
after, context_.get()); + std::vector> after; + r.rule->Transform(before, after, this->context_.get()); // Rewrite rule should provide at most 1 expression PELOTON_ASSERT(after.size() <= 1); @@ -494,14 +501,20 @@ void BottomUpRewrite::execute() { // saturated, also childs are already been rewritten if (!after.empty()) { auto &new_expr = after[0]; - context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); - PushTask( - new BottomUpRewrite(group_id_, context_, rule_set_name_, false)); + this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); + this->PushTask( + new BottomUpRewrite(group_id_, this->context_, rule_set_name_, false)); return; } } cur_group_expr->SetRuleExplored(r.rule); } } + + +// Explicitly instantiate +template class TopDownRewrite; +template class BottomUpRewrite; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/pattern.cpp b/src/optimizer/pattern.cpp index d7665d678bb..03ab3858a0d 100644 --- a/src/optimizer/pattern.cpp +++ b/src/optimizer/pattern.cpp @@ -15,17 +15,24 @@ namespace peloton { namespace optimizer { -Pattern::Pattern(OpType op) : _type(op) {} +template +Pattern::Pattern(OperatorType op) : _type(op) {} -void Pattern::AddChild(std::shared_ptr child) { +template +void Pattern::AddChild(std::shared_ptr> child) { children.push_back(child); } -const std::vector> &Pattern::Children() const { +template +const std::vector>> &Pattern::Children() const { return children; } -OpType Pattern::Type() const { return _type; } +template +OperatorType Pattern::Type() const { return _type; } + +// Explicitly instantiate +template class Pattern; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/property_enforcer.cpp b/src/optimizer/property_enforcer.cpp index 834cf9a76d7..98013f214f4 100644 --- a/src/optimizer/property_enforcer.cpp +++ b/src/optimizer/property_enforcer.cpp @@ -19,8 +19,10 @@ namespace peloton { namespace optimizer { -std::shared_ptr 
PropertyEnforcer::EnforceProperty( - GroupExpression* gexpr, Property* property) { +std::shared_ptr> PropertyEnforcer::EnforceProperty( + GroupExpression* gexpr, + Property* property) { + input_gexpr_ = gexpr; property->Accept(this); return output_gexpr_; @@ -33,13 +35,13 @@ void PropertyEnforcer::Visit(const PropertyColumns *) { void PropertyEnforcer::Visit(const PropertySort *) { std::vector child_groups(1, input_gexpr_->GetGroupID()); output_gexpr_ = - std::make_shared(PhysicalOrderBy::make(), child_groups); + std::make_shared>(PhysicalOrderBy::make(), child_groups); } void PropertyEnforcer::Visit(const PropertyDistinct *) { std::vector child_groups(1, input_gexpr_->GetGroupID()); output_gexpr_ = - std::make_shared(PhysicalDistinct::make(), child_groups); + std::make_shared>(PhysicalDistinct::make(), child_groups); } void PropertyEnforcer::Visit(const PropertyLimit *) {} diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 8c72ed17fa8..8d39ab3b94b 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -16,7 +16,24 @@ namespace peloton { namespace optimizer { -int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { +template +int Rule::Promise( + GroupExpression *group_expr, + OptimizeContext *context) const { + + //(TODO): handle general/AbstractExpression case + PELOTON_ASSERT(group_expr); + PELOTON_ASSERT(context); + PELOTON_ASSERT(0); + return 0; +} + +// Specialization due to OpType +template <> +int Rule::Promise( + GroupExpression *group_expr, + OptimizeContext *context) const { + (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -27,7 +44,14 @@ int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { return LOG_PROMISE; } -RuleSet::RuleSet() { +template +RuleSet::RuleSet() { + //(TODO): handle general/AbstractExpression case + PELOTON_ASSERT(0); +} + +template <> +RuleSet::RuleSet() { AddTransformationRule(new InnerJoinCommutativity()); 
AddTransformationRule(new InnerJoinAssociativity()); AddImplementationRule(new LogicalDeleteToPhysical()); diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 8574e00f337..ed24b5680ec 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -34,15 +34,15 @@ namespace optimizer { InnerJoinCommutativity::InnerJoinCommutativity() { type_ = RuleType::INNER_JOIN_COMMUTE; - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); - match_pattern = std::make_shared(OpType::InnerJoin); + std::shared_ptr> left_child(std::make_shared>(OpType::Leaf)); + std::shared_ptr> right_child(std::make_shared>(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::InnerJoin); match_pattern->AddChild(left_child); match_pattern->AddChild(right_child); } bool InnerJoinCommutativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)expr; return true; @@ -51,7 +51,7 @@ bool InnerJoinCommutativity::Check(std::shared_ptr expr, void InnerJoinCommutativity::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto join_op = input->Op().As(); auto join_predicates = std::vector(join_op->join_predicates); @@ -74,20 +74,20 @@ InnerJoinAssociativity::InnerJoinAssociativity() { type_ = RuleType::INNER_JOIN_ASSOCIATE; // Create left nested join - auto left_child = std::make_shared(OpType::InnerJoin); - left_child->AddChild(std::make_shared(OpType::Leaf)); - left_child->AddChild(std::make_shared(OpType::Leaf)); + auto left_child = std::make_shared>(OpType::InnerJoin); + left_child->AddChild(std::make_shared>(OpType::Leaf)); + left_child->AddChild(std::make_shared>(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + std::shared_ptr> 
right_child(std::make_shared>(OpType::Leaf)); - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared>(OpType::InnerJoin); match_pattern->AddChild(left_child); match_pattern->AddChild(right_child); } // TODO: As far as I know, theres nothing else that needs to be checked bool InnerJoinAssociativity::Check(std::shared_ptr expr, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)expr; return true; @@ -96,7 +96,7 @@ bool InnerJoinAssociativity::Check(std::shared_ptr expr, void InnerJoinAssociativity::Transform( std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const { + OptimizeContext *context) const { // NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN // right) Variables are named accordingly to above transformation auto parent_join = input->Op().As(); @@ -179,11 +179,11 @@ void InnerJoinAssociativity::Transform( GetToDummyScan::GetToDummyScan() { type_ = RuleType::GET_TO_DUMMY_SCAN; - match_pattern = std::make_shared(OpType::Get); + match_pattern = std::make_shared>(OpType::Get); } bool GetToDummyScan::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalGet *get = plan->Op().As(); return get->table == nullptr; @@ -192,7 +192,7 @@ bool GetToDummyScan::Check(std::shared_ptr plan, void GetToDummyScan::Transform( UNUSED_ATTRIBUTE std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto result_plan = std::make_shared(DummyScan::make()); transformed.push_back(result_plan); @@ -203,11 +203,11 @@ void GetToDummyScan::Transform( GetToSeqScan::GetToSeqScan() { type_ = RuleType::GET_TO_SEQ_SCAN; - match_pattern = std::make_shared(OpType::Get); + match_pattern = std::make_shared>(OpType::Get); } bool GetToSeqScan::Check(std::shared_ptr plan, - OptimizeContext *context) 
const { + OptimizeContext *context) const { (void)context; const LogicalGet *get = plan->Op().As(); return get->table != nullptr; @@ -216,7 +216,7 @@ bool GetToSeqScan::Check(std::shared_ptr plan, void GetToSeqScan::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalGet *get = input->Op().As(); auto result_plan = std::make_shared( @@ -235,11 +235,11 @@ void GetToSeqScan::Transform( GetToIndexScan::GetToIndexScan() { type_ = RuleType::GET_TO_INDEX_SCAN; - match_pattern = std::make_shared(OpType::Get); + match_pattern = std::make_shared>(OpType::Get); } bool GetToIndexScan::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { // If there is a index for the table, return true, // else return false (void)context; @@ -255,7 +255,7 @@ bool GetToIndexScan::Check(std::shared_ptr plan, void GetToIndexScan::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { UNUSED_ATTRIBUTE std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 0); @@ -409,13 +409,13 @@ void GetToIndexScan::Transform( /// LogicalQueryDerivedGetToPhysical LogicalQueryDerivedGetToPhysical::LogicalQueryDerivedGetToPhysical() { type_ = RuleType::QUERY_DERIVED_GET_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalQueryDerivedGet); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalQueryDerivedGet); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalQueryDerivedGetToPhysical::Check( - std::shared_ptr expr, OptimizeContext *context) const { + std::shared_ptr expr, OptimizeContext *context) const { (void)context; (void)expr; return true; @@ -424,7 +424,7 @@ bool 
LogicalQueryDerivedGetToPhysical::Check( void LogicalQueryDerivedGetToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalQueryDerivedGet *get = input->Op().As(); auto result_plan = @@ -439,19 +439,19 @@ void LogicalQueryDerivedGetToPhysical::Transform( /// LogicalExternalFileGetToPhysical LogicalExternalFileGetToPhysical::LogicalExternalFileGetToPhysical() { type_ = RuleType::EXTERNAL_FILE_GET_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalExternalFileGet); + match_pattern = std::make_shared>(OpType::LogicalExternalFileGet); } bool LogicalExternalFileGetToPhysical::Check( UNUSED_ATTRIBUTE std::shared_ptr plan, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { return true; } void LogicalExternalFileGetToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const auto *get = input->Op().As(); auto result_plan = std::make_shared( @@ -467,13 +467,13 @@ void LogicalExternalFileGetToPhysical::Transform( /// LogicalDeleteToPhysical LogicalDeleteToPhysical::LogicalDeleteToPhysical() { type_ = RuleType::DELETE_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalDelete); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalDelete); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)plan; (void)context; return true; @@ -482,7 +482,7 @@ bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, void LogicalDeleteToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE 
OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalDelete *delete_op = input->Op().As(); auto result = std::make_shared( PhysicalDelete::make(delete_op->target_table)); @@ -495,13 +495,13 @@ void LogicalDeleteToPhysical::Transform( /// LogicalUpdateToPhysical LogicalUpdateToPhysical::LogicalUpdateToPhysical() { type_ = RuleType::UPDATE_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalUpdate); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalUpdate); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)plan; (void)context; return true; @@ -510,7 +510,7 @@ bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, void LogicalUpdateToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalUpdate *update_op = input->Op().As(); auto result = std::make_shared( PhysicalUpdate::make(update_op->target_table, update_op->updates)); @@ -523,13 +523,13 @@ void LogicalUpdateToPhysical::Transform( /// LogicalInsertToPhysical LogicalInsertToPhysical::LogicalInsertToPhysical() { type_ = RuleType::INSERT_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalInsert); - // std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalInsert); + // std::shared_ptr> child(std::make_shared>(OpType::Leaf)); // match_pattern->AddChild(child); } bool LogicalInsertToPhysical::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)plan; (void)context; return true; @@ -538,7 +538,7 @@ bool LogicalInsertToPhysical::Check(std::shared_ptr plan, void 
LogicalInsertToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalInsert *insert_op = input->Op().As(); auto result = std::make_shared(PhysicalInsert::make( insert_op->target_table, insert_op->columns, insert_op->values)); @@ -551,13 +551,13 @@ void LogicalInsertToPhysical::Transform( /// LogicalInsertSelectToPhysical LogicalInsertSelectToPhysical::LogicalInsertSelectToPhysical() { type_ = RuleType::INSERT_SELECT_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalInsertSelect); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalInsertSelect); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalInsertSelectToPhysical::Check( - std::shared_ptr plan, OptimizeContext *context) const { + std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; return true; @@ -566,7 +566,7 @@ bool LogicalInsertSelectToPhysical::Check( void LogicalInsertSelectToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalInsertSelect *insert_op = input->Op().As(); auto result = std::make_shared( PhysicalInsertSelect::make(insert_op->target_table)); @@ -579,14 +579,14 @@ void LogicalInsertSelectToPhysical::Transform( /// LogicalAggregateAndGroupByToHashGroupBy LogicalGroupByToHashGroupBy::LogicalGroupByToHashGroupBy() { type_ = RuleType::AGGREGATE_TO_HASH_AGGREGATE; - match_pattern = std::make_shared(OpType::LogicalAggregateAndGroupBy); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalAggregateAndGroupBy); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool 
LogicalGroupByToHashGroupBy::Check( UNUSED_ATTRIBUTE std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalAggregateAndGroupBy *agg_op = plan->Op().As(); @@ -596,7 +596,7 @@ bool LogicalGroupByToHashGroupBy::Check( void LogicalGroupByToHashGroupBy::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalAggregateAndGroupBy *agg_op = input->Op().As(); auto result = std::make_shared( @@ -610,14 +610,14 @@ void LogicalGroupByToHashGroupBy::Transform( /// LogicalAggregateToPhysical LogicalAggregateToPhysical::LogicalAggregateToPhysical() { type_ = RuleType::AGGREGATE_TO_PLAIN_AGGREGATE; - match_pattern = std::make_shared(OpType::LogicalAggregateAndGroupBy); - std::shared_ptr child(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalAggregateAndGroupBy); + std::shared_ptr> child(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool LogicalAggregateToPhysical::Check( UNUSED_ATTRIBUTE std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalAggregateAndGroupBy *agg_op = plan->Op().As(); @@ -627,7 +627,7 @@ bool LogicalAggregateToPhysical::Check( void LogicalAggregateToPhysical::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto result = std::make_shared(PhysicalAggregate::make()); PELOTON_ASSERT(input->Children().size() == 1); result->PushChild(input->Children().at(0)); @@ -640,11 +640,11 @@ InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { type_ = RuleType::INNER_JOIN_TO_NL_JOIN; // TODO NLJoin currently only support left deep tree - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr 
right_child(std::make_shared(OpType::Leaf)); + std::shared_ptr> left_child(std::make_shared>(OpType::Leaf)); + std::shared_ptr> right_child(std::make_shared>(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared>(OpType::InnerJoin); // Add node - we match join relation R and S match_pattern->AddChild(left_child); @@ -654,7 +654,7 @@ InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { } bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -663,7 +663,7 @@ bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, void InnerJoinToInnerNLJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join const LogicalInnerJoin *inner_join = input->Op().As(); @@ -701,11 +701,11 @@ InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { type_ = RuleType::INNER_JOIN_TO_HASH_JOIN; // Make three node types for pattern matching - std::shared_ptr left_child(std::make_shared(OpType::Leaf)); - std::shared_ptr right_child(std::make_shared(OpType::Leaf)); + std::shared_ptr> left_child(std::make_shared>(OpType::Leaf)); + std::shared_ptr> right_child(std::make_shared>(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::InnerJoin); + match_pattern = std::make_shared>(OpType::InnerJoin); // Add node - we match join relation R and S as well as the predicate exp match_pattern->AddChild(left_child); @@ -715,7 +715,7 @@ InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { } bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -724,7 +724,7 @@ bool 
InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, void InnerJoinToInnerHashJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join const LogicalInnerJoin *inner_join = input->Op().As(); @@ -761,12 +761,12 @@ void InnerJoinToInnerHashJoin::Transform( ImplementDistinct::ImplementDistinct() { type_ = RuleType::IMPLEMENT_DISTINCT; - match_pattern = std::make_shared(OpType::LogicalDistinct); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalDistinct); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } bool ImplementDistinct::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -775,7 +775,7 @@ bool ImplementDistinct::Check(std::shared_ptr plan, void ImplementDistinct::Transform( std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; auto result_plan = std::make_shared(PhysicalDistinct::make()); @@ -792,12 +792,12 @@ void ImplementDistinct::Transform( ImplementLimit::ImplementLimit() { type_ = RuleType::IMPLEMENT_LIMIT; - match_pattern = std::make_shared(OpType::LogicalLimit); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalLimit); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } bool ImplementLimit::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -806,7 +806,7 @@ bool ImplementLimit::Check(std::shared_ptr plan, void ImplementLimit::Transform( std::shared_ptr input, std::vector> &transformed, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; const LogicalLimit 
*limit_op = input->Op().As(); @@ -825,20 +825,20 @@ void ImplementLimit::Transform( /// LogicalExport to Physical Export LogicalExportToPhysicalExport::LogicalExportToPhysicalExport() { type_ = RuleType::EXPORT_EXTERNAL_FILE_TO_PHYSICAL; - match_pattern = std::make_shared(OpType::LogicalExportExternalFile); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalExportExternalFile); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } bool LogicalExportToPhysicalExport::Check( UNUSED_ATTRIBUTE std::shared_ptr plan, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { return true; } void LogicalExportToPhysicalExport::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { const auto *export_op = input->Op().As(); auto result_plan = @@ -863,26 +863,26 @@ PushFilterThroughJoin::PushFilterThroughJoin() { type_ = RuleType::PUSH_FILTER_THROUGH_JOIN; // Make three node types for pattern matching - std::shared_ptr child(std::make_shared(OpType::InnerJoin)); - child->AddChild(std::make_shared(OpType::Leaf)); - child->AddChild(std::make_shared(OpType::Leaf)); + std::shared_ptr> child(std::make_shared>(OpType::InnerJoin)); + child->AddChild(std::make_shared>(OpType::Leaf)); + child->AddChild(std::make_shared>(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::LogicalFilter); + match_pattern = std::make_shared>(OpType::LogicalFilter); // Add node - we match join relation R and S as well as the predicate exp match_pattern->AddChild(child); } bool PushFilterThroughJoin::Check(std::shared_ptr, - OptimizeContext *) const { + OptimizeContext *) const { return true; } void PushFilterThroughJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const 
{ + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PushFilterThroughJoin::Transform"); auto &memo = context->metadata->memo; auto join_op_expr = input->Children().at(0); @@ -955,26 +955,26 @@ void PushFilterThroughJoin::Transform( PushFilterThroughAggregation::PushFilterThroughAggregation() { type_ = RuleType::PUSH_FILTER_THROUGH_JOIN; - std::shared_ptr child( - std::make_shared(OpType::LogicalAggregateAndGroupBy)); - child->AddChild(std::make_shared(OpType::Leaf)); + std::shared_ptr> child( + std::make_shared>(OpType::LogicalAggregateAndGroupBy)); + child->AddChild(std::make_shared>(OpType::Leaf)); // Initialize a pattern for optimizer to match - match_pattern = std::make_shared(OpType::LogicalFilter); + match_pattern = std::make_shared>(OpType::LogicalFilter); // Add node - we match (filter)->(aggregation)->(leaf) match_pattern->AddChild(child); } bool PushFilterThroughAggregation::Check(std::shared_ptr, - OptimizeContext *) const { + OptimizeContext *) const { return true; } void PushFilterThroughAggregation::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PushFilterThroughAggregation::Transform"); auto aggregation_op = input->Children().at(0)->Op().As(); @@ -1022,16 +1022,16 @@ void PushFilterThroughAggregation::Transform( CombineConsecutiveFilter::CombineConsecutiveFilter() { type_ = RuleType::COMBINE_CONSECUTIVE_FILTER; - match_pattern = std::make_shared(OpType::LogicalFilter); - std::shared_ptr child( - std::make_shared(OpType::LogicalFilter)); - child->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalFilter); + std::shared_ptr> child( + std::make_shared>(OpType::LogicalFilter)); + child->AddChild(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(child); } bool CombineConsecutiveFilter::Check(std::shared_ptr plan, - OptimizeContext *context) const { + 
OptimizeContext *context) const { (void)context; (void)plan; @@ -1048,7 +1048,7 @@ bool CombineConsecutiveFilter::Check(std::shared_ptr plan, void CombineConsecutiveFilter::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto child_filter = input->Children()[0]; auto root_predicates = input->Op().As()->predicates; @@ -1071,14 +1071,14 @@ void CombineConsecutiveFilter::Transform( EmbedFilterIntoGet::EmbedFilterIntoGet() { type_ = RuleType::EMBED_FILTER_INTO_GET; - match_pattern = std::make_shared(OpType::LogicalFilter); - std::shared_ptr child(std::make_shared(OpType::Get)); + match_pattern = std::make_shared>(OpType::LogicalFilter); + std::shared_ptr> child(std::make_shared>(OpType::Get)); match_pattern->AddChild(child); } bool EmbedFilterIntoGet::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; return true; @@ -1087,7 +1087,7 @@ bool EmbedFilterIntoGet::Check(std::shared_ptr plan, void EmbedFilterIntoGet::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { auto get = input->Children()[0]->Op().As(); auto predicates = input->Op().As()->predicates; @@ -1105,13 +1105,13 @@ void EmbedFilterIntoGet::Transform( MarkJoinToInnerJoin::MarkJoinToInnerJoin() { type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN; - match_pattern = std::make_shared(OpType::LogicalMarkJoin); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalMarkJoin); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } -int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int 
MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -1122,7 +1122,7 @@ int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, } bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; @@ -1135,7 +1135,7 @@ bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, void MarkJoinToInnerJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("MarkJoinToInnerJoin::Transform"); UNUSED_ATTRIBUTE auto mark_join = input->Op().As(); auto &join_children = input->Children(); @@ -1156,13 +1156,13 @@ void MarkJoinToInnerJoin::Transform( SingleJoinToInnerJoin::SingleJoinToInnerJoin() { type_ = RuleType::MARK_JOIN_GET_TO_INNER_JOIN; - match_pattern = std::make_shared(OpType::LogicalSingleJoin); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalSingleJoin); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); } -int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -1173,7 +1173,7 @@ int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, } bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; @@ -1186,7 +1186,7 @@ bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, void SingleJoinToInnerJoin::Transform( 
std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("SingleJoinToInnerJoin::Transform"); UNUSED_ATTRIBUTE auto single_join = input->Op().As(); auto &join_children = input->Children(); @@ -1207,15 +1207,15 @@ void SingleJoinToInnerJoin::Transform( PullFilterThroughMarkJoin::PullFilterThroughMarkJoin() { type_ = RuleType::PULL_FILTER_THROUGH_MARK_JOIN; - match_pattern = std::make_shared(OpType::LogicalMarkJoin); - match_pattern->AddChild(std::make_shared(OpType::Leaf)); - auto filter = std::make_shared(OpType::LogicalFilter); - filter->AddChild(std::make_shared(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalMarkJoin); + match_pattern->AddChild(std::make_shared>(OpType::Leaf)); + auto filter = std::make_shared>(OpType::LogicalFilter); + filter->AddChild(std::make_shared>(OpType::Leaf)); match_pattern->AddChild(filter); } -int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -1226,7 +1226,7 @@ int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, } bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, - OptimizeContext *context) const { + OptimizeContext *context) const { (void)context; (void)plan; @@ -1241,7 +1241,7 @@ bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, void PullFilterThroughMarkJoin::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PullFilterThroughMarkJoin::Transform"); UNUSED_ATTRIBUTE auto mark_join = input->Op().As(); auto &join_children = input->Children(); @@ -1269,14 +1269,14 @@ void 
PullFilterThroughMarkJoin::Transform( PullFilterThroughAggregation::PullFilterThroughAggregation() { type_ = RuleType::PULL_FILTER_THROUGH_AGGREGATION; - auto filter = std::make_shared(OpType::LogicalFilter); - filter->AddChild(std::make_shared(OpType::Leaf)); - match_pattern = std::make_shared(OpType::LogicalAggregateAndGroupBy); + auto filter = std::make_shared>(OpType::LogicalFilter); + filter->AddChild(std::make_shared>(OpType::Leaf)); + match_pattern = std::make_shared>(OpType::LogicalAggregateAndGroupBy); match_pattern->AddChild(filter); } -int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable @@ -1287,7 +1287,7 @@ int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, } bool PullFilterThroughAggregation::Check( - std::shared_ptr plan, OptimizeContext *context) const { + std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1302,7 +1302,7 @@ bool PullFilterThroughAggregation::Check( void PullFilterThroughAggregation::Transform( std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PullFilterThroughAggregation::Transform"); auto &memo = context->metadata->memo; auto &filter_expr = input->Children()[0]; diff --git a/src/optimizer/stats/child_stats_deriver.cpp b/src/optimizer/stats/child_stats_deriver.cpp index d320547915c..0fbf2720d99 100644 --- a/src/optimizer/stats/child_stats_deriver.cpp +++ b/src/optimizer/stats/child_stats_deriver.cpp @@ -20,9 +20,9 @@ namespace peloton { namespace optimizer { using std::vector; -vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, +vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, ExprSet 
required_cols, - Memo *memo) { + Memo *memo) { required_cols_ = required_cols; gexpr_ = gexpr; memo_ = memo; diff --git a/src/optimizer/stats/stats_calculator.cpp b/src/optimizer/stats/stats_calculator.cpp index d086938a817..815e309290b 100644 --- a/src/optimizer/stats/stats_calculator.cpp +++ b/src/optimizer/stats/stats_calculator.cpp @@ -26,8 +26,8 @@ namespace peloton { namespace optimizer { -void StatsCalculator::CalculateStats(GroupExpression *gexpr, - ExprSet required_cols, Memo *memo, +void StatsCalculator::CalculateStats(GroupExpression *gexpr, + ExprSet required_cols, Memo *memo, concurrency::TransactionContext *txn) { gexpr_ = gexpr; memo_ = memo; diff --git a/test/include/optimizer/mock_task.h b/test/include/optimizer/mock_task.h index 32e5e1b8da4..7e18f458445 100644 --- a/test/include/optimizer/mock_task.h +++ b/test/include/optimizer/mock_task.h @@ -20,10 +20,10 @@ namespace peloton { namespace optimizer { namespace test { -class MockTask : public optimizer::OptimizerTask { +class MockTask : public optimizer::OptimizerTask { public: MockTask() - : optimizer::OptimizerTask(nullptr, OptimizerTaskType::OPTIMIZE_GROUP) {} + : optimizer::OptimizerTask(nullptr, OptimizerTaskType::OPTIMIZE_GROUP) {} MOCK_METHOD0(execute, void()); }; diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index 23f520596dc..9868cfa924e 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ b/test/optimizer/optimizer_rule_test.cpp @@ -132,8 +132,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { optimizer.GetMetadata().memo.InsertExpression( optimizer.GetMetadata().MakeGroupExpression(parent_join), true); - OptimizeContext *root_context = - new OptimizeContext(&(optimizer.GetMetadata()), nullptr); + OptimizeContext *root_context = + new OptimizeContext(&(optimizer.GetMetadata()), nullptr); EXPECT_EQ(left_leaf, parent_join->Children()[0]->Children()[0]); EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); @@ -227,8 
+227,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { optimizer.GetMetadata().memo.InsertExpression( optimizer.GetMetadata().MakeGroupExpression(parent_join), true); - OptimizeContext *root_context = - new OptimizeContext(&(optimizer.GetMetadata()), nullptr); + OptimizeContext *root_context = + new OptimizeContext(&(optimizer.GetMetadata()), nullptr); EXPECT_EQ(left_leaf, parent_join->Children()[0]->Children()[0]); EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index f1ffd6add66..f9fd843b3b3 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -49,7 +49,9 @@ using namespace optimizer; class OptimizerTests : public PelotonTest { protected: - GroupExpression *GetSingleGroupExpression(Memo &memo, GroupExpression *expr, + GroupExpression *GetSingleGroupExpression( + Memo &memo, + GroupExpression *expr, int child_group_idx) { auto group = memo.GetGroupByID(expr->GetChildGroupId(child_group_idx)); EXPECT_EQ(1, group->GetLogicalExpressions().size()); @@ -343,19 +345,19 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { auto bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); bind_node_visitor.BindNameToNode(parse_tree); - std::shared_ptr gexpr = + std::shared_ptr> gexpr = optimizer.TestInsertQueryTree(parse_tree, txn); std::vector child_groups = {gexpr->GetGroupID()}; - std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); + std::shared_ptr> head_gexpr = + std::make_shared>(Operator(), child_groups); - std::shared_ptr root_context = - std::make_shared(&(optimizer.GetMetadata()), nullptr); + std::shared_ptr> root_context = + std::make_shared>(&(optimizer.GetMetadata()), nullptr); auto task_stack = - std::unique_ptr(new OptimizerTaskStack()); + std::unique_ptr>(new OptimizerTaskStack()); optimizer.GetMetadata().SetTaskPool(task_stack.get()); - task_stack->Push(new 
TopDownRewrite(gexpr->GetGroupID(), root_context, + task_stack->Push(new TopDownRewrite(gexpr->GetGroupID(), root_context, RewriteRuleSetName::PREDICATE_PUSH_DOWN)); while (!task_stack->Empty()) { @@ -430,19 +432,19 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { auto bind_node_visitor = binder::BindNodeVisitor(txn, DEFAULT_DB_NAME); bind_node_visitor.BindNameToNode(parse_tree); - std::shared_ptr gexpr = + std::shared_ptr> gexpr = optimizer.TestInsertQueryTree(parse_tree, txn); std::vector child_groups = {gexpr->GetGroupID()}; - std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); + std::shared_ptr> head_gexpr = + std::make_shared>(Operator(), child_groups); - std::shared_ptr root_context = - std::make_shared(&(optimizer.GetMetadata()), nullptr); + std::shared_ptr> root_context = + std::make_shared>(&(optimizer.GetMetadata()), nullptr); auto task_stack = - std::unique_ptr(new OptimizerTaskStack()); + std::unique_ptr>(new OptimizerTaskStack()); optimizer.GetMetadata().SetTaskPool(task_stack.get()); - task_stack->Push(new TopDownRewrite(gexpr->GetGroupID(), root_context, + task_stack->Push(new TopDownRewrite(gexpr->GetGroupID(), root_context, RewriteRuleSetName::PREDICATE_PUSH_DOWN)); while (!task_stack->Empty()) { @@ -486,14 +488,14 @@ TEST_F(OptimizerTests, ExecuteTaskStackTest) { optimizer::Optimizer optimizer; const int root_group_id = 0; const auto root_group = - new Group(root_group_id, std::unordered_set()); + new Group(root_group_id, std::unordered_set()); optimizer.GetMetadata().memo.Groups().emplace_back(root_group); auto required_prop = std::make_shared(PropertySet()); - auto root_context = std::make_shared( + auto root_context = std::make_shared>( &(optimizer.GetMetadata()), required_prop); auto task_stack = - std::unique_ptr(new OptimizerTaskStack()); + std::unique_ptr>(new OptimizerTaskStack()); auto &timer = optimizer.GetMetadata().timer; // Insert tasks into task stack From 1be43be8f47a8a3c0806fe170d5b16484e986360 Mon 
Sep 17 00:00:00 2001 From: William Zhang <17zhangw@gmail.com> Date: Mon, 1 Apr 2019 19:24:11 -0400 Subject: [PATCH 02/14] Enabled rewriting of a single rule (constant = constant), check tests. Possibly annoying problems w.r.t Peloton/terrier: (1) Use of unique_ptr/raw pointer as opposed to shared_ptr in AbstractExpression (2) AbstractExpression equality comparison method Additional components needed: - Dynamic/template/strategy rule evaluation (particularly comparison) - Repeated/multi-level application of rules - Layer to convert from memo -> AbstractExpression - Some refactoring w.r.t templated code - Better AbsExpr_Container/Expression indirection layer (intended to present a similar interface exposed by Operator/OperatorExpression relied upon by core logic) - Proper memory management strategy (tightly coupled to problem #1) --- src/include/common/internal_types.h | 4 + src/include/optimizer/absexpr_expression.h | 146 +++++++++++++++++++ src/include/optimizer/rewriter.h | 50 +++++++ src/include/optimizer/rule.h | 3 +- src/include/optimizer/rule_rewrite.h | 38 +++++ src/optimizer/binding.cpp | 50 ++++++- src/optimizer/group.cpp | 2 + src/optimizer/group_expression.cpp | 2 + src/optimizer/memo.cpp | 42 +++++- src/optimizer/optimizer_task.cpp | 30 ++++ src/optimizer/pattern.cpp | 1 + src/optimizer/rewriter.cpp | 154 +++++++++++++++++++++ src/optimizer/rule.cpp | 14 +- src/optimizer/rule_rewrite.cpp | 99 +++++++++++++ test/optimizer/rewriter_test.cpp | 121 ++++++++++++++++ 15 files changed, 744 insertions(+), 12 deletions(-) create mode 100644 src/include/optimizer/absexpr_expression.h create mode 100644 src/include/optimizer/rewriter.h create mode 100644 src/include/optimizer/rule_rewrite.h create mode 100644 src/optimizer/rewriter.cpp create mode 100644 src/optimizer/rule_rewrite.cpp create mode 100644 test/optimizer/rewriter_test.cpp diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 96b45f9e42b..21de29a080e 100644 --- 
a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1383,6 +1383,10 @@ enum class RuleType : uint32_t { PULL_FILTER_THROUGH_MARK_JOIN, PULL_FILTER_THROUGH_AGGREGATION, + // AST rewrite rules (logical -> logical) + // Removes ConstantValueExpression = ConstantValueExpression + COMP_EQUALITY_ELIMINATION, + // Place holder to generate number of rules compile time NUM_RULES diff --git a/src/include/optimizer/absexpr_expression.h b/src/include/optimizer/absexpr_expression.h new file mode 100644 index 00000000000..1e712f41daa --- /dev/null +++ b/src/include/optimizer/absexpr_expression.h @@ -0,0 +1,146 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// absexpr_expression.h +// +// Identification: src/include/optimizer/absexpr_expression.h +// +//===----------------------------------------------------------------------===// + +#pragma once + +// AbstractExpression Definition +#include "expression/abstract_expression.h" + +#include +#include + +namespace peloton { +namespace optimizer { + +// (TODO): rethink the AbsExpr_Container/Expression approach in comparion to abstract +// Most of the core rule/optimizer code relies on the concept of an Operator / +// OperatorExpression and the interface that the two functions respectively expose. +// +// The annoying part is that an AbstractExpression blends together an Operator +// and OperatorExpression. Second part, the AbstractExpression does not export the +// correct interface that the rest of the system depends on. +// +// As an extreme level of simplification (sort of hacky), an AbsExpr_Container is +// analogous to Operator and wraps a single AbstractExpression node. AbsExpr_Expression +// is analogous to OperatorExpression. +// +// AbsExpr_Container does *not* handle memory correctly w.r.t internal instantiations +// from Rule transformation. 
This is since Peloton itself mixes unique_ptrs and +// hands out raw pointers which makes adding a shared_ptr here extremely problematic. +// terrier uses only shared_ptr when dealing with AbstractExpression trees. + +class AbsExpr_Container { + public: + AbsExpr_Container(); + + AbsExpr_Container(const expression::AbstractExpression *expr) { + node = expr; + } + + // Return operator type + ExpressionType GetType() const { + if (IsDefined()) { + return node->GetExpressionType(); + } + return ExpressionType::INVALID; + } + + const expression::AbstractExpression *GetExpr() const { + return node; + } + + // Operator contains Logical node + bool IsLogical() const { + return true; + } + + // Operator contains Physical node + bool IsPhysical() const { + return false; + } + + std::string GetName() const { + if (IsDefined()) { + return node->GetExpressionName(); + } + + return "Undefined"; + } + + hash_t Hash() const { + if (IsDefined()) { + return node->Hash(); + } + return 0; + } + + bool operator==(const AbsExpr_Container &r) { + if (IsDefined() && r.IsDefined()) { + // (TODO): need a better way to determine deep equality + + // NOTE: + // Without proper equality determinations, the groups will + // not be assigned correctly. Arguably, terrier does this + // better because a blind ExactlyEquals on different types + // of ConstantValueExpression under Peloton will crash! + + // For now, just return (false). + // I don't anticipate this will affect correctness, just + // performance, since duplicate trees will have to evaluated + // over and over again, rather than being able to "borrow" + // a previous tree's rewrite. + // + // Probably not worth to create a "validator" since porting + // this to terrier anyways (?). == does not check Value + // so it's broken. ExactlyEqual requires precondition checking. 
+ return false; + } else if (!IsDefined() && !r.IsDefined()) { + return true; + } + return false; + } + + // Operator contains physical or logical operator node + bool IsDefined() const { + return node != nullptr; + } + + private: + const expression::AbstractExpression *node; +}; + +class AbsExpr_Expression { + public: + AbsExpr_Expression(AbsExpr_Container op): op(op) {}; + + void PushChild(std::shared_ptr op) { + children.push_back(op); + } + + void PopChild() { + children.pop_back(); + } + + const std::vector> &Children() const { + return children; + } + + const AbsExpr_Container &Op() const { + return op; + } + + private: + AbsExpr_Container op; + std::vector> children; +}; + +} // namespace optimizer +} // namespace peloton + diff --git a/src/include/optimizer/rewriter.h b/src/include/optimizer/rewriter.h new file mode 100644 index 00000000000..c57fe25cf2b --- /dev/null +++ b/src/include/optimizer/rewriter.h @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rewriter.h +// +// Identification: src/include/optimizer/rewriter.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include "expression/abstract_expression.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/optimizer_task_pool.h" +#include "optimizer/absexpr_expression.h" + +namespace peloton { +namespace optimizer { + +class Rewriter { + + public: + Rewriter(const Rewriter &) = delete; + Rewriter &operator=(const Rewriter &) = delete; + Rewriter(Rewriter &&) = delete; + Rewriter &operator=(Rewriter &&) = delete; + + Rewriter(); + + expression::AbstractExpression* RewriteExpression(const expression::AbstractExpression *expr); + void Reset(); + + OptimizerMetadata &GetMetadata() { return metadata_; } + + std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression 
*expr); + + private: + void ExecuteTaskStack(OptimizerTaskStack &task_stack); + void RewriteLoop(int root_group_id); + std::shared_ptr> ConvertTree(const expression::AbstractExpression *expr); + OptimizerMetadata metadata_; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/rule.h b/src/include/optimizer/rule.h index b6f85a4c085..b7681433405 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -113,7 +113,8 @@ struct RuleWithPromise { enum class RewriteRuleSetName : uint32_t { PREDICATE_PUSH_DOWN = 0, - UNNEST_SUBQUERY + UNNEST_SUBQUERY, + COMPARATOR_ELIMINATION }; /** diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h new file mode 100644 index 00000000000..c14d7e1d336 --- /dev/null +++ b/src/include/optimizer/rule_rewrite.h @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rule_rewrite.h +// +// Identification: src/include/optimizer/rule_rewrite.h +// +// Copyright (c) 2015-16, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "optimizer/rule.h" +#include "optimizer/absexpr_expression.h" + +#include + +namespace peloton { +namespace optimizer { + +class ComparatorElimination: public Rule { + public: + ComparatorElimination(); + + int Promise(GroupExpression *group_expr, + OptimizeContext *context) const override; + + bool Check(std::shared_ptr plan, + OptimizeContext *context) const override; + + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; +} // namespace optimizer +} // namespace peloton diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index e0091d2d0b1..bab57c490d2 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -15,6 +15,7 @@ #include "common/logger.h" #include 
"optimizer/operator_visitor.h" #include "optimizer/optimizer.h" +#include "optimizer/absexpr_expression.h" namespace peloton { namespace optimizer { @@ -38,9 +39,41 @@ GroupBindingIterator::GroupBindingIterator( template bool GroupBindingIterator::HasNext() { + //(TODO): refactor this and specialization to reduce duplicated code + if (current_iterator_) { + // Check if still have bindings in current item + if (!current_iterator_->HasNext()) { + current_iterator_.reset(nullptr); + current_item_index_++; + } + } + + if (current_iterator_ == nullptr) { + // Keep checking item iterators until we find a match + while (current_item_index_ < num_group_items_) { + current_iterator_.reset(new GroupExprBindingIterator( + this->memo_, + target_group_->GetLogicalExpressions()[current_item_index_].get(), + pattern_)); + + if (current_iterator_->HasNext()) { + break; + } + + current_iterator_.reset(nullptr); + current_item_index_++; + } + } + + std::cout << "Is there a group bind: " << (current_iterator_ != nullptr) << "\n"; + return current_iterator_ != nullptr; +} + +// Specialization +template <> +bool GroupBindingIterator::HasNext() { LOG_TRACE("HasNext"); - //(TODO): GroupBindingIterator::HasNext() probably needs specialization if (pattern_->Type() == OpType::Leaf) { return current_item_index_ == 0; } @@ -56,7 +89,7 @@ bool GroupBindingIterator::HasNext() { if (current_iterator_ == nullptr) { // Keep checking item iterators until we find a match while (current_item_index_ < num_group_items_) { - current_iterator_.reset(new GroupExprBindingIterator( + current_iterator_.reset(new GroupExprBindingIterator( this->memo_, target_group_->GetLogicalExpressions()[current_item_index_].get(), pattern_)); @@ -75,10 +108,16 @@ bool GroupBindingIterator::HasNext() { template std::shared_ptr GroupBindingIterator::Next() { - //(TODO): GroupBindingIterator::Next() probably needs specialization + std::cout << "Fetching next iterator\n"; + return current_iterator_->Next(); +} + +// 
Specialization +template <> +std::shared_ptr GroupBindingIterator::Next() { if (pattern_->Type() == OpType::Leaf) { current_item_index_ = num_group_items_; - return std::make_shared(LeafOperator::make(group_id_)); + return std::make_shared(LeafOperator::make(group_id_)); } return current_iterator_->Next(); } @@ -189,5 +228,8 @@ std::shared_ptr GroupExprBindingIterator; template class GroupExprBindingIterator; +template class GroupBindingIterator; +template class GroupExprBindingIterator; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/group.cpp b/src/optimizer/group.cpp index 5f248a415db..7de90a31a31 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -12,6 +12,7 @@ #include "optimizer/group.h" #include "optimizer/operator_expression.h" +#include "optimizer/absexpr_expression.h" #include "common/logger.h" @@ -171,6 +172,7 @@ bool Group::HasColumnStats(std::string column_ // Explicitly instantiate template class Group; +template class Group; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/group_expression.cpp b/src/optimizer/group_expression.cpp index 08c88897e6b..98540606558 100644 --- a/src/optimizer/group_expression.cpp +++ b/src/optimizer/group_expression.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "common/internal_types.h" +#include "optimizer/absexpr_expression.h" #include "optimizer/group_expression.h" #include "optimizer/group.h" #include "optimizer/rule.h" @@ -112,6 +113,7 @@ bool GroupExpression::HasRuleExplored(Rule; +template class GroupExpression; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index 5f86f988fa3..2988dd16b89 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -14,6 +14,7 @@ #include "optimizer/memo.h" #include "optimizer/operators.h" #include "optimizer/stats/stats_calculator.h" +#include "optimizer/absexpr_expression.h" 
namespace peloton { namespace optimizer { @@ -29,10 +30,15 @@ Memo::Memo() {} //===--------------------------------------------------------------------===// template GroupID Memo::AddNewGroup(std::shared_ptr> gexpr) { - //(TODO): handle general case/AbstractExpressions (void)gexpr; - PELOTON_ASSERT(0); - return 0; + + GroupID new_group_id = groups_.size(); + // Find out the table alias that this group represents + std::unordered_set table_aliases; + + groups_.emplace_back( + new Group(new_group_id, std::move(table_aliases))); + return new_group_id; } template <> @@ -81,9 +87,32 @@ GroupExpression *MemoSetGroupID((*it)->GetGroupID()); + std::cout << "Same Group discovered..\n"; + return *it; + } else { + group_expressions_.insert(gexpr.get()); + // New expression, so try to insert into an existing group or + // create a new group if none specified + GroupID group_id; + if (target_group == UNDEFINED_GROUP) { + group_id = AddNewGroup(gexpr); + } else { + group_id = target_group; + } + + Group *group = GetGroupByID(group_id); + group->AddExpression(gexpr, enforced); + + std::cout << "Inserted into new group...size(): " << group_expressions_.size() << "\n"; + return gexpr.get(); + } } // Specialization for Memo::InsertExpression due to OpType @@ -154,6 +183,7 @@ const std::string Memo::GetInfo() const { // Explicitly instantiate template template class Memo; +template class Memo; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index 8026828ee2d..1263561b38e 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -18,6 +18,7 @@ #include "optimizer/child_property_deriver.h" #include "optimizer/stats/stats_calculator.h" #include "optimizer/stats/child_stats_deriver.h" +#include "optimizer/absexpr_expression.h" namespace peloton { namespace optimizer { @@ -429,6 +430,14 @@ void TopDownRewrite::execute() { if (iterator.HasNext()) { auto before = iterator.Next(); 
PELOTON_ASSERT(!iterator.HasNext()); + + // (TODO): verify correctness + // Check whether rule actually can be applied + // as opposed to a structural level test + if (!r.rule->Check(before, this->context_.get())) { + continue; + } + std::vector> after; r.rule->Transform(before, after, this->context_.get()); @@ -485,14 +494,27 @@ void BottomUpRewrite::execute() { std::sort(valid_rules.begin(), valid_rules.end(), std::greater>()); + std::cout << "Rule pass starting\n"; for (auto &r : valid_rules) { GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, r.rule->GetMatchPattern()); if (iterator.HasNext()) { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); + + std::cout << "Structural match found\n"; + + // (TODO): verify correctness + // Check whether rule actually can be applied + // as opposed to a structural level test + if (!r.rule->Check(before, this->context_.get())) { + continue; + } + + std::cout << "Rule integrity check passed\n"; std::vector> after; r.rule->Transform(before, after, this->context_.get()); + std::cout << "Rule Transformation conducted\n"; // Rewrite rule should provide at most 1 expression PELOTON_ASSERT(after.size() <= 1); @@ -504,11 +526,16 @@ void BottomUpRewrite::execute() { this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); this->PushTask( new BottomUpRewrite(group_id_, this->context_, rule_set_name_, false)); + + std::cout << "Rewrote expression overwrote!\n"; + std::cout << "Rule Pass ended, starting again\n"; return; } } cur_group_expr->SetRuleExplored(r.rule); } + + std::cout << "Rule Pass ended\n"; } @@ -516,5 +543,8 @@ void BottomUpRewrite::execute() { template class TopDownRewrite; template class BottomUpRewrite; +template class TopDownRewrite; +template class BottomUpRewrite; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/pattern.cpp b/src/optimizer/pattern.cpp index 03ab3858a0d..23b976888cf 100644 --- a/src/optimizer/pattern.cpp +++ 
b/src/optimizer/pattern.cpp @@ -33,6 +33,7 @@ OperatorType Pattern::Type() const { return _type; } // Explicitly instantiate template class Pattern; +template class Pattern; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/rewriter.cpp b/src/optimizer/rewriter.cpp new file mode 100644 index 00000000000..b4cbd3fd0d4 --- /dev/null +++ b/src/optimizer/rewriter.cpp @@ -0,0 +1,154 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rewriter.cpp +// +// Identification: src/optimizer/rewriter.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include + +#include "optimizer/optimizer.h" +#include "optimizer/rewriter.h" +#include "common/exception.h" + +#include "optimizer/cost_model/trivial_cost_model.h" +#include "optimizer/operator_visitor.h" +#include "optimizer/optimize_context.h" +#include "optimizer/optimizer_task_pool.h" +#include "optimizer/rule.h" +#include "optimizer/rule_impls.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/absexpr_expression.h" +#include "expression/abstract_expression.h" +#include "expression/constant_value_expression.h" + +using std::vector; +using std::unordered_map; +using std::shared_ptr; +using std::unique_ptr; +using std::move; +using std::pair; +using std::make_shared; + +namespace peloton { +namespace optimizer { + +using OptimizerMetadataTemplate = OptimizerMetadata; + +using OptimizeContextTemplate = OptimizeContext; + +using OptimizerTaskStackTemplate = OptimizerTaskStack; + +using TopDownRewriteTemplate = TopDownRewrite; + +using BottomUpRewriteTemplate = BottomUpRewrite; + +using GroupExpressionTemplate = GroupExpression; + +using GroupTemplate = Group; + +Rewriter::Rewriter() : metadata_(nullptr) { + metadata_ = OptimizerMetadataTemplate(nullptr); +} + +void Rewriter::RewriteLoop(int root_group_id) { + 
std::shared_ptr root_context = + std::make_shared(&metadata_, nullptr); + auto task_stack = + std::unique_ptr(new OptimizerTaskStackTemplate()); + metadata_.SetTaskPool(task_stack.get()); + + // Perform rewrite first + task_stack->Push(new BottomUpRewriteTemplate(root_group_id, root_context, RewriteRuleSetName::COMPARATOR_ELIMINATION, false)); + + ExecuteTaskStack(*task_stack); +} + +expression::AbstractExpression* Rewriter::RewriteExpression(const expression::AbstractExpression *expr) { + // (TODO): convert AbstractExpression to AbsExpr_Expression... + // This is needed in order to provide template classes the correct interface. + // This should probably be better abstracted away. + std::shared_ptr gexpr = ConvertTree(expr); + std::cout << "Converted tree to internal data structures\n"; + + GroupID root_id = gexpr->GetGroupID(); + RewriteLoop(root_id); + std::cout << "Performed rewrite loop pass\n"; + + // (TODO): rebuild AbstractExpression tree from memo + // The real strategy is very similar to Optimizer::ChooseBestPlan + // It should be possible to use the Children stored in GroupExpression + // to recursively pull from memo_ until a GroupExpression where + // GetChildrenGroupsSize() == 0 (which indicates the leaf). 
+ + // For now, this just returns the top level node + GroupTemplate* group = metadata_.memo.GetGroupByID(root_id); + std::vector> exprs = group->GetLogicalExpressions(); + + PELOTON_ASSERT(exprs.size() > 0); + std::cout << "Final logical expressions retrieved\n"; + + // Take the first one + gexpr = exprs[0]; + PELOTON_ASSERT(gexpr->GetChildrenGroupsSize() == 0); + + // (TODO): build a layer which can go from AbsExpr_Container -> new AbstractExpression + // (TODO): build a layer which can go from AbsExpr_Expression -> new AbstractExpression + // right now this is just hard-coded which is bad + PELOTON_ASSERT(gexpr->Op().GetType() == ExpressionType::VALUE_CONSTANT); + auto casted = static_cast(gexpr->Op().GetExpr()); + auto rebuilt = new expression::ConstantValueExpression(casted->GetValue()); + std::cout << "Rebuilt expression\n"; + + Reset(); + std::cout << "Reset the rewriter\n"; + return rebuilt; +} + +void Rewriter::Reset() { + metadata_ = OptimizerMetadataTemplate(nullptr); +} + +std::shared_ptr Rewriter::ConvertToAbsExpr(const expression::AbstractExpression* expr) { + + // (TODO): need to think about how memory management would work w.r.t Peloton/terrier + // for now, this just directly wraps each AbstractExpression in a AbsExpr_Container + // which is then wrapped in an AbsExpr_Expression to provide the same Operator/OperatorExpression + // interface that is relied upon by the rest of the code base. 
+ + auto container = AbsExpr_Container(expr); + auto exp = std::make_shared(container); + for (size_t i = 0; i < expr->GetChildrenSize(); i++) { + exp->PushChild(ConvertToAbsExpr(expr->GetChild(i))); + } + return exp; +} + +std::shared_ptr Rewriter::ConvertTree( + const expression::AbstractExpression *expr) { + std::cout << "Entered Rewriter::ConvertTree\n"; + + std::shared_ptr exp = ConvertToAbsExpr(expr); + std::cout << "Converted to AbsExpr_Expression\n"; + + std::shared_ptr gexpr; + metadata_.RecordTransformedExpression(exp, gexpr); + std::cout << "Initial loaded into memo\n"; + return gexpr; +} + +void Rewriter::ExecuteTaskStack(OptimizerTaskStackTemplate &task_stack) { + // Iterate through the task stack + while (!task_stack.Empty()) { + auto task = task_stack.Pop(); + task->execute(); + } +} + +} // namespace optimizer +} // namespace peloton diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 8d39ab3b94b..47014b5b2ae 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -12,6 +12,8 @@ #include "optimizer/rule_impls.h" #include "optimizer/group_expression.h" +#include "optimizer/absexpr_expression.h" +#include "optimizer/rule_rewrite.h" namespace peloton { namespace optimizer { @@ -46,8 +48,14 @@ int Rule::Promise( template RuleSet::RuleSet() { - //(TODO): handle general/AbstractExpression case PELOTON_ASSERT(0); + // should never be invoked +} + +template <> +RuleSet::RuleSet() { + AddRewriteRule(RewriteRuleSetName::COMPARATOR_ELIMINATION, + new ComparatorElimination()); } template <> @@ -88,5 +96,9 @@ RuleSet::RuleSet() { new PullFilterThroughAggregation()); } +// Explicitly instantiate +template class Rule; +template class Rule; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp new file mode 100644 index 00000000000..24c8259e317 --- /dev/null +++ b/src/optimizer/rule_rewrite.cpp @@ -0,0 +1,99 @@ +#include + +#include "catalog/column_catalog.h" 
+#include "catalog/index_catalog.h" +#include "catalog/table_catalog.h" +#include "optimizer/operators.h" +#include "optimizer/absexpr_expression.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/properties.h" +#include "optimizer/rule_rewrite.h" +#include "optimizer/util.h" +#include "type/value_factory.h" + +namespace peloton { +namespace optimizer { + +ComparatorElimination::ComparatorElimination() { + type_ = RuleType::COMP_EQUALITY_ELIMINATION; + + match_pattern = std::make_shared>(ExpressionType::COMPARE_EQUAL); + auto left = std::make_shared>(ExpressionType::VALUE_CONSTANT); + auto right = std::make_shared>(ExpressionType::VALUE_CONSTANT); + match_pattern->AddChild(left); + match_pattern->AddChild(right); +} + +int ComparatorElimination::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { + (void)group_expr; + (void)context; + + //(TODO): is this correct, proceed to structural binding? + std::cout << "Promise hit\n"; + return 1; +} + +bool ComparatorElimination::Check(std::shared_ptr plan, + OptimizeContext *context) const { + (void)context; + (void)plan; + + std::cout << "Check hit\n"; + + //(TODO): perform checking more gracefully + // Technically, if structure matches, rule should always be applied + PELOTON_ASSERT(plan != nullptr); + PELOTON_ASSERT(plan->Children().size() == 2); + PELOTON_ASSERT(plan->Op().GetType() == ExpressionType::COMPARE_EQUAL); + + // Verify the structure of the tree is correct + auto left = plan->Children()[0]; + auto right = plan->Children()[1]; + PELOTON_ASSERT(left->Children().size() == 0); + PELOTON_ASSERT(left->Op().GetType() == ExpressionType::VALUE_CONSTANT); + PELOTON_ASSERT(right->Children().size() == 0); + PELOTON_ASSERT(right->Op().GetType() == ExpressionType::VALUE_CONSTANT); + + return true; +} + +void ComparatorElimination::Transform(std::shared_ptr input, + std::vector> &transformed, + UNUSED_ATTRIBUTE OptimizeContext *context) const { + (void)transformed; + (void)context; + + 
// (TODO): create a wrapper for evaluating ConstantValue relations + + // Extract the AbstractExpression through indirection layer + auto left = input->Children()[0]->Op().GetExpr(); + auto right = input->Children()[1]->Op().GetExpr(); + auto lv = static_cast(left); + auto rv = static_cast(right); + lv = const_cast(lv); + rv = const_cast(rv); + + // Get the Value from ConstantValueExpression + auto lvalue = lv->GetValue(); + auto rvalue = rv->GetValue(); + + // Need to check type equality to prevent assertion failure + // This is only a Peloton issue (terrier checks type for you) + // (TODO): perform checking through a class/strategy + bool is_equal = (lvalue.GetTypeId() == rvalue.GetTypeId()) && + (lv->ExactlyEquals(*rv)); + + // Create the transformed expression + type::Value val = type::ValueFactory::GetBooleanValue(is_equal); + auto eq = new expression::ConstantValueExpression(val); + auto cnt = AbsExpr_Container(eq); + auto shared = std::make_shared(cnt); + + // (TODO): figure out how to free these expressions + // (TODO): Terrier uses shared_ptr but Peloton has this + // awkward mixture of raw pointers and unique_ptr + transformed.push_back(shared); +} +} // namespace optimizer +} // namespace peloton diff --git a/test/optimizer/rewriter_test.cpp b/test/optimizer/rewriter_test.cpp new file mode 100644 index 00000000000..48857b891c0 --- /dev/null +++ b/test/optimizer/rewriter_test.cpp @@ -0,0 +1,121 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// operator_test.cpp +// +// Identification: test/optimizer/operator_test.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include "common/harness.h" + +#include "optimizer/operators.h" +#include "optimizer/rewriter.h" +#include "expression/constant_value_expression.h" +#include "expression/comparison_expression.h" +#include 
"expression/tuple_value_expression.h" +#include "type/value_factory.h" +#include "type/value_peeker.h" +#include "optimizer/rule_rewrite.h" + +namespace peloton { + +namespace test { + +using namespace optimizer; + +class RewriterTests : public PelotonTest {}; + +TEST_F(RewriterTests, ConvertAbsExpr) { + type::Value leftValue = type::ValueFactory::GetIntegerValue(1); + type::Value rightValue = type::ValueFactory::GetIntegerValue(2); + auto left = new expression::ConstantValueExpression(leftValue); + auto right = new expression::ConstantValueExpression(rightValue); + auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + Rewriter *rewriter = new Rewriter(); + + auto absexpr = rewriter->ConvertToAbsExpr(common); + EXPECT_TRUE(absexpr != nullptr); + EXPECT_TRUE(absexpr->Op().GetType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(absexpr->Children().size() == 2); + + auto lefta = absexpr->Children()[0]; + auto righta = absexpr->Children()[1]; + EXPECT_TRUE(lefta != nullptr && righta != nullptr); + EXPECT_TRUE(lefta->Op().GetType() == righta->Op().GetType()); + EXPECT_TRUE(lefta->Op().GetType() == ExpressionType::VALUE_CONSTANT); + + auto left_cve = static_cast(lefta->Op().GetExpr()); + auto right_cve = static_cast(righta->Op().GetExpr()); + EXPECT_TRUE(left_cve == left); + EXPECT_TRUE(right_cve == right); + + // Try applying the rule + ComparatorElimination rule; + EXPECT_TRUE(rule.Check(absexpr, nullptr) == true); + + std::vector> transform; + rule.Transform(absexpr, transform, nullptr); + EXPECT_TRUE(transform.size() == 1); + + delete rewriter; + delete common; + + auto tr_expr = transform[0]; + EXPECT_TRUE(tr_expr != nullptr); + EXPECT_TRUE(tr_expr->Op().GetType() == ExpressionType::VALUE_CONSTANT); + EXPECT_TRUE(tr_expr->Children().size() == 0); + + auto tr_cve = static_cast(tr_expr->Op().GetExpr()); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(tr_cve->GetValue()) == false); + + // (TODO): hack to fix the memory leak 
bubbled from Transform() + delete tr_cve; +} + +TEST_F(RewriterTests, SingleCompareEqualRewritePassFalse) { + type::Value leftValue = type::ValueFactory::GetIntegerValue(3); + type::Value rightValue = type::ValueFactory::GetIntegerValue(2); + auto left = new expression::ConstantValueExpression(leftValue); + auto right = new expression::ConstantValueExpression(rightValue); + auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(common); + + delete rewriter; + delete common; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + +TEST_F(RewriterTests, SingleCompareEqualRewritePassTrue) { + type::Value leftValue = type::ValueFactory::GetIntegerValue(4); + type::Value rightValue = type::ValueFactory::GetIntegerValue(4); + auto left = new expression::ConstantValueExpression(leftValue); + auto right = new expression::ConstantValueExpression(rightValue); + auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(common); + + delete rewriter; + delete common; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + delete rewrote; +} +} // namespace test +} // namespace peloton From f4d4e8fba44fce79e73585569c46ba4c0b21ad36 Mon Sep 17 00:00:00 2001 From: William Zhang <17zhangw@gmail.com> Date: Fri, 5 Apr 2019 12:53:56 -0400 Subject: [PATCH 03/14] Refactoring + full tree rebuilding (at least 2 levels work) What still doesn't work/don't care about yet/not done - proper memory management (terrier uses shared_ptr anyways) - other 1-level rewrites, multi-layer rewrites, other expr rewrites - how can we define a 
grammar to programmatically create these rewrites? (the one we have is way too static...) - in relation to logical equivalence: (1) how do we generate logically equivalent expressions: - multi-pass using generating rules (similar to ApplyRule) OR - from Pattern A, generate logically equivalent set of patterns P OR - transform input expression to match certain specification OR - ??? (2) what operators do we support translating? - probably (a AND b) ====> (b AND a) - probably (a OR b) ====> (b OR a) - probably (a = b) ====> (b = a) - maybe more??? (3) do we want multi level translations? - i.e (a AND b) AND c ====> (a AND (b AND c)) - what order do we do these in? May have to modify these operations: - Some assertions in TopDownRewrite/BottomUpRewrite w.r.t to the iterator - Possibly binding.cpp / optimizer_metadata.h / optimizer_task.cpp Issues still pending: - Comparing Values (Matt email/discussion) - r.rule->Check (terrier issue #332) --- src/include/optimizer/absexpr_expression.h | 37 ++++++++++ src/include/optimizer/binding.h | 3 + src/include/optimizer/memo.h | 5 ++ src/include/optimizer/rewriter.h | 1 + src/include/optimizer/rule_rewrite.h | 7 ++ src/optimizer/binding.cpp | 37 ++-------- src/optimizer/group.cpp | 7 +- src/optimizer/memo.cpp | 61 ++++++---------- src/optimizer/optimizer_task.cpp | 13 +--- src/optimizer/rewriter.cpp | 64 ++++++++-------- src/optimizer/rule.cpp | 9 ++- src/optimizer/rule_rewrite.cpp | 20 ++--- test/optimizer/rewriter_test.cpp | 85 ++++++++++++++++++++++ 13 files changed, 216 insertions(+), 133 deletions(-) diff --git a/src/include/optimizer/absexpr_expression.h b/src/include/optimizer/absexpr_expression.h index 1e712f41daa..745881ccfb0 100644 --- a/src/include/optimizer/absexpr_expression.h +++ b/src/include/optimizer/absexpr_expression.h @@ -12,6 +12,9 @@ // AbstractExpression Definition #include "expression/abstract_expression.h" +#include "expression/conjunction_expression.h" +#include "expression/comparison_expression.h" 
+#include "expression/constant_value_expression.h" #include #include @@ -112,6 +115,40 @@ class AbsExpr_Container { return node != nullptr; } + //(TODO): fix memory management once go to terrier + expression::AbstractExpression *Rebuild(std::vector children) { + switch (GetType()) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_LIKE: + case ExpressionType::COMPARE_NOTLIKE: + case ExpressionType::COMPARE_IN: + case ExpressionType::COMPARE_DISTINCT_FROM: { + PELOTON_ASSERT(children.size() == 2); + return new expression::ComparisonExpression(GetType(), children[0], children[1]); + } + case ExpressionType::CONJUNCTION_AND: + case ExpressionType::CONJUNCTION_OR: { + PELOTON_ASSERT(children.size() == 2); + return new expression::ConjunctionExpression(GetType(), children[0], children[1]); + } + case ExpressionType::VALUE_CONSTANT: { + PELOTON_ASSERT(children.size() == 0); + auto cve = static_cast(node); + return new expression::ConstantValueExpression(cve->GetValue()); + } + default: { + int type = static_cast(GetType()); + LOG_ERROR("Unimplemented Rebuild() for %d found", type); + return nullptr; + } + } + } + private: const expression::AbstractExpression *node; }; diff --git a/src/include/optimizer/binding.h b/src/include/optimizer/binding.h index 616bda57782..57756b07b83 100644 --- a/src/include/optimizer/binding.h +++ b/src/include/optimizer/binding.h @@ -63,6 +63,9 @@ class GroupBindingIterator : public BindingIterator *target_group_; size_t num_group_items_; + // Internal function for HasNext() + bool HasNextBinding(); + size_t current_item_index_; std::unique_ptr> current_iterator_; }; diff --git a/src/include/optimizer/memo.h b/src/include/optimizer/memo.h index be67f961c9a..4bc77009de8 100644 --- 
a/src/include/optimizer/memo.h +++ b/src/include/optimizer/memo.h @@ -91,6 +91,11 @@ class Memo { private: GroupID AddNewGroup(std::shared_ptr> gexpr); + // Internal InsertExpression function + GroupExpression* InsertExpr( + std::shared_ptr> gexpr, + GroupID target_group, bool enforced); + // The group owns the group expressions, not the memo std::unordered_set*, GExprPtrHash, diff --git a/src/include/optimizer/rewriter.h b/src/include/optimizer/rewriter.h index c57fe25cf2b..796b10f7779 100644 --- a/src/include/optimizer/rewriter.h +++ b/src/include/optimizer/rewriter.h @@ -40,6 +40,7 @@ class Rewriter { std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression *expr); private: + expression::AbstractExpression* RebuildExpression(int root_group); void ExecuteTaskStack(OptimizerTaskStack &task_stack); void RewriteLoop(int root_group_id); std::shared_ptr> ConvertTree(const expression::AbstractExpression *expr); diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h index c14d7e1d336..fe0f2b829bf 100644 --- a/src/include/optimizer/rule_rewrite.h +++ b/src/include/optimizer/rule_rewrite.h @@ -20,6 +20,13 @@ namespace peloton { namespace optimizer { +/* Rules are applied from high to low priority */ +enum class RulePriority : int { + HIGH = 3, + MEDIUM = 2, + LOW = 1 +}; + class ComparatorElimination: public Rule { public: ComparatorElimination(); diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index bab57c490d2..2975dce336c 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -38,8 +38,7 @@ GroupBindingIterator::GroupBindingIterator( } template -bool GroupBindingIterator::HasNext() { - //(TODO): refactor this and specialization to reduce duplicated code +bool GroupBindingIterator::HasNextBinding() { if (current_iterator_) { // Check if still have bindings in current item if (!current_iterator_->HasNext()) { @@ -65,10 +64,14 @@ bool GroupBindingIterator::HasNext() { } } - std::cout << 
"Is there a group bind: " << (current_iterator_ != nullptr) << "\n"; return current_iterator_ != nullptr; } +template +bool GroupBindingIterator::HasNext() { + return HasNextBinding(); +} + // Specialization template <> bool GroupBindingIterator::HasNext() { @@ -78,37 +81,11 @@ bool GroupBindingIterator::HasNext() { return current_item_index_ == 0; } - if (current_iterator_) { - // Check if still have bindings in current item - if (!current_iterator_->HasNext()) { - current_iterator_.reset(nullptr); - current_item_index_++; - } - } - - if (current_iterator_ == nullptr) { - // Keep checking item iterators until we find a match - while (current_item_index_ < num_group_items_) { - current_iterator_.reset(new GroupExprBindingIterator( - this->memo_, - target_group_->GetLogicalExpressions()[current_item_index_].get(), - pattern_)); - - if (current_iterator_->HasNext()) { - break; - } - - current_iterator_.reset(nullptr); - current_item_index_++; - } - } - - return current_iterator_ != nullptr; + return HasNextBinding(); } template std::shared_ptr GroupBindingIterator::Next() { - std::cout << "Fetching next iterator\n"; return current_iterator_->Next(); } diff --git a/src/optimizer/group.cpp b/src/optimizer/group.cpp index 7de90a31a31..99f9efd9171 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -33,7 +33,12 @@ void Group::AddExpression( std::shared_ptr> expr, bool enforced) { - //(TODO): rethink how separation works with AbstractExpressions + // Additional assertion checks for AddExpression() with AST rewriting + if (std::is_same::value) { + PELOTON_ASSERT(!enforced); + PELOTON_ASSERT(!expr->Op().IsPhysical()); + } + // Do duplicate detection expr->SetGroupID(id_); if (enforced) diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index 2988dd16b89..691cc9d3832 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -74,27 +74,13 @@ GroupID Memo::AddNewGroup(std::shared_ptr -GroupExpression *Memo::InsertExpression( - 
std::shared_ptr> gexpr, - bool enforced) { +GroupExpression* Memo::InsertExpr( + std::shared_ptr> gexpr, + GroupID target_group, bool enforced) { - return InsertExpression(gexpr, UNDEFINED_GROUP, enforced); -} - -template -GroupExpression *Memo::InsertExpression( - std::shared_ptr> gexpr, - GroupID target_group, - bool enforced) { - - //(TODO): refactor this with the specialization auto it = group_expressions_.find(gexpr.get()); - std::cout << "group_expressions_.size(): " << group_expressions_.size() << "\n"; - std::cout << "InsertExpression (" << gexpr << "," << target_group << ")\n"; - if (it != group_expressions_.end()) { gexpr->SetGroupID((*it)->GetGroupID()); - std::cout << "Same Group discovered..\n"; return *it; } else { group_expressions_.insert(gexpr.get()); @@ -109,18 +95,34 @@ GroupExpression *Memo *group = GetGroupByID(group_id); group->AddExpression(gexpr, enforced); - - std::cout << "Inserted into new group...size(): " << group_expressions_.size() << "\n"; return gexpr.get(); } } +template +GroupExpression *Memo::InsertExpression( + std::shared_ptr> gexpr, + bool enforced) { + + return InsertExpression(gexpr, UNDEFINED_GROUP, enforced); +} + +template +GroupExpression *Memo::InsertExpression( + std::shared_ptr> gexpr, + GroupID target_group, + bool enforced) { + + return InsertExpr(gexpr, target_group, enforced); +} + // Specialization for Memo::InsertExpression due to OpType template <> GroupExpression *Memo::InsertExpression( std::shared_ptr> gexpr, GroupID target_group, bool enforced) { + // If leaf, then just return if (gexpr->Op().GetType() == OpType::Leaf) { const LeafOperator *leaf = gexpr->Op().As(); @@ -130,26 +132,7 @@ GroupExpression *MemoSetGroupID((*it)->GetGroupID()); - return *it; - } else { - group_expressions_.insert(gexpr.get()); - // New expression, so try to insert into an existing group or - // create a new group if none specified - GroupID group_id; - if (target_group == UNDEFINED_GROUP) { - group_id = AddNewGroup(gexpr); - } 
else { - group_id = target_group; - } - Group *group = GetGroupByID(group_id); - group->AddExpression(gexpr, enforced); - return gexpr.get(); - } + return InsertExpr(gexpr, target_group, enforced); } template diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index 1263561b38e..d8fc17b7e27 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -431,7 +431,7 @@ void TopDownRewrite::execute() { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); - // (TODO): verify correctness + // (TODO): pending terrier issue #332 // Check whether rule actually can be applied // as opposed to a structural level test if (!r.rule->Check(before, this->context_.get())) { @@ -494,7 +494,6 @@ void BottomUpRewrite::execute() { std::sort(valid_rules.begin(), valid_rules.end(), std::greater>()); - std::cout << "Rule pass starting\n"; for (auto &r : valid_rules) { GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, r.rule->GetMatchPattern()); @@ -502,19 +501,15 @@ void BottomUpRewrite::execute() { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); - std::cout << "Structural match found\n"; - - // (TODO): verify correctness + // (TODO): pending terrier issue #332 // Check whether rule actually can be applied // as opposed to a structural level test if (!r.rule->Check(before, this->context_.get())) { continue; } - std::cout << "Rule integrity check passed\n"; std::vector> after; r.rule->Transform(before, after, this->context_.get()); - std::cout << "Rule Transformation conducted\n"; // Rewrite rule should provide at most 1 expression PELOTON_ASSERT(after.size() <= 1); @@ -527,15 +522,11 @@ void BottomUpRewrite::execute() { this->PushTask( new BottomUpRewrite(group_id_, this->context_, rule_set_name_, false)); - std::cout << "Rewrote expression overwrote!\n"; - std::cout << "Rule Pass ended, starting again\n"; return; } } cur_group_expr->SetRuleExplored(r.rule); } - - std::cout << 
"Rule Pass ended\n"; } diff --git a/src/optimizer/rewriter.cpp b/src/optimizer/rewriter.cpp index b4cbd3fd0d4..d23d998e51d 100644 --- a/src/optimizer/rewriter.cpp +++ b/src/optimizer/rewriter.cpp @@ -69,45 +69,45 @@ void Rewriter::RewriteLoop(int root_group_id) { ExecuteTaskStack(*task_stack); } +expression::AbstractExpression* Rewriter::RebuildExpression(int root) { + auto cur_group = metadata_.memo.GetGroupByID(root); + auto exprs = cur_group->GetLogicalExpressions(); + + // (TODO): what should we do if exprs.size() > 1? + PELOTON_ASSERT(exprs.size() > 0); + auto expr = exprs[0]; + + std::vector child_groups = expr->GetChildGroupIDs(); + std::vector child_exprs; + for (auto group : child_groups) { + // Build children first + expression::AbstractExpression *child = RebuildExpression(group); + PELOTON_ASSERT(child != nullptr); + + child_exprs.push_back(child); + } + + AbsExpr_Container c = expr->Op(); + return c.Rebuild(child_exprs); +} + expression::AbstractExpression* Rewriter::RewriteExpression(const expression::AbstractExpression *expr) { - // (TODO): convert AbstractExpression to AbsExpr_Expression... + // (TODO): do we need to actually convert to a wrapper? // This is needed in order to provide template classes the correct interface. // This should probably be better abstracted away. std::shared_ptr gexpr = ConvertTree(expr); - std::cout << "Converted tree to internal data structures\n"; + LOG_DEBUG("Converted tree to internal data structures"); GroupID root_id = gexpr->GetGroupID(); RewriteLoop(root_id); - std::cout << "Performed rewrite loop pass\n"; - - // (TODO): rebuild AbstractExpression tree from memo - // The real strategy is very similar to Optimizer::ChooseBestPlan - // It should be possible to use the Children stored in GroupExpression - // to recursively pull from memo_ until a GroupExpression where - // GetChildrenGroupsSize() == 0 (which indicates the leaf). 
- - // For now, this just returns the top level node - GroupTemplate* group = metadata_.memo.GetGroupByID(root_id); - std::vector> exprs = group->GetLogicalExpressions(); - - PELOTON_ASSERT(exprs.size() > 0); - std::cout << "Final logical expressions retrieved\n"; - - // Take the first one - gexpr = exprs[0]; - PELOTON_ASSERT(gexpr->GetChildrenGroupsSize() == 0); + LOG_DEBUG("Performed rewrite loop pass"); - // (TODO): build a layer which can go from AbsExpr_Container -> new AbstractExpression - // (TODO): build a layer which can go from AbsExpr_Expression -> new AbstractExpression - // right now this is just hard-coded which is bad - PELOTON_ASSERT(gexpr->Op().GetType() == ExpressionType::VALUE_CONSTANT); - auto casted = static_cast(gexpr->Op().GetExpr()); - auto rebuilt = new expression::ConstantValueExpression(casted->GetValue()); - std::cout << "Rebuilt expression\n"; + expression::AbstractExpression *expr_tree = RebuildExpression(root_id); + LOG_DEBUG("Rebuilt expression tree from memo table"); Reset(); - std::cout << "Reset the rewriter\n"; - return rebuilt; + LOG_DEBUG("Reset the rewriter"); + return expr_tree; } void Rewriter::Reset() { @@ -116,7 +116,7 @@ void Rewriter::Reset() { std::shared_ptr Rewriter::ConvertToAbsExpr(const expression::AbstractExpression* expr) { - // (TODO): need to think about how memory management would work w.r.t Peloton/terrier + // (TODO): fix memory management once we get to terrier // for now, this just directly wraps each AbstractExpression in a AbsExpr_Container // which is then wrapped in an AbsExpr_Expression to provide the same Operator/OperatorExpression // interface that is relied upon by the rest of the code base. 
@@ -131,14 +131,10 @@ std::shared_ptr Rewriter::ConvertToAbsExpr(const expression: std::shared_ptr Rewriter::ConvertTree( const expression::AbstractExpression *expr) { - std::cout << "Entered Rewriter::ConvertTree\n"; std::shared_ptr exp = ConvertToAbsExpr(expr); - std::cout << "Converted to AbsExpr_Expression\n"; - std::shared_ptr gexpr; metadata_.RecordTransformedExpression(exp, gexpr); - std::cout << "Initial loaded into memo\n"; return gexpr; } diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 47014b5b2ae..0d14104060d 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -23,9 +23,10 @@ int Rule::Promise( GroupExpression *group_expr, OptimizeContext *context) const { - //(TODO): handle general/AbstractExpression case - PELOTON_ASSERT(group_expr); - PELOTON_ASSERT(context); + (void)group_expr; + (void)context; + + LOG_ERROR("Rule::Promise for rewrite engine not implemented!"); PELOTON_ASSERT(0); return 0; } @@ -48,8 +49,8 @@ int Rule::Promise( template RuleSet::RuleSet() { + LOG_ERROR("Must invoke specialization of RuleSet constructor"); PELOTON_ASSERT(0); - // should never be invoked } template <> diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp index 24c8259e317..88d23092c31 100644 --- a/src/optimizer/rule_rewrite.cpp +++ b/src/optimizer/rule_rewrite.cpp @@ -28,10 +28,7 @@ int ComparatorElimination::Promise(GroupExpression *context) const { (void)group_expr; (void)context; - - //(TODO): is this correct, proceed to structural binding? 
- std::cout << "Promise hit\n"; - return 1; + return static_cast(RulePriority::HIGH); } bool ComparatorElimination::Check(std::shared_ptr plan, @@ -39,15 +36,12 @@ bool ComparatorElimination::Check(std::shared_ptr plan, (void)context; (void)plan; - std::cout << "Check hit\n"; - - //(TODO): perform checking more gracefully - // Technically, if structure matches, rule should always be applied + // If any of these assertions fail, something is seriously wrong with GroupExprBinding + // Verify the structure of the tree is correct PELOTON_ASSERT(plan != nullptr); PELOTON_ASSERT(plan->Children().size() == 2); PELOTON_ASSERT(plan->Op().GetType() == ExpressionType::COMPARE_EQUAL); - // Verify the structure of the tree is correct auto left = plan->Children()[0]; auto right = plan->Children()[1]; PELOTON_ASSERT(left->Children().size() == 0); @@ -55,6 +49,7 @@ bool ComparatorElimination::Check(std::shared_ptr plan, PELOTON_ASSERT(right->Children().size() == 0); PELOTON_ASSERT(right->Op().GetType() == ExpressionType::VALUE_CONSTANT); + // Technically, if structure matches, rule should always be applied return true; } @@ -64,7 +59,7 @@ void ComparatorElimination::Transform(std::shared_ptr input, (void)transformed; (void)context; - // (TODO): create a wrapper for evaluating ConstantValue relations + // (TODO): create a wrapper for evaluating ConstantValue relations (pending email reply) // Extract the AbstractExpression through indirection layer auto left = input->Children()[0]->Op().GetExpr(); @@ -80,7 +75,6 @@ void ComparatorElimination::Transform(std::shared_ptr input, // Need to check type equality to prevent assertion failure // This is only a Peloton issue (terrier checks type for you) - // (TODO): perform checking through a class/strategy bool is_equal = (lvalue.GetTypeId() == rvalue.GetTypeId()) && (lv->ExactlyEquals(*rv)); @@ -90,9 +84,7 @@ void ComparatorElimination::Transform(std::shared_ptr input, auto cnt = AbsExpr_Container(eq); auto shared = std::make_shared(cnt); 
- // (TODO): figure out how to free these expressions - // (TODO): Terrier uses shared_ptr but Peloton has this - // awkward mixture of raw pointers and unique_ptr + // (TODO): figure out memory management once go to terrier (which use shared_ptr) transformed.push_back(shared); } } // namespace optimizer diff --git a/test/optimizer/rewriter_test.cpp b/test/optimizer/rewriter_test.cpp index 48857b891c0..48c4b0420b9 100644 --- a/test/optimizer/rewriter_test.cpp +++ b/test/optimizer/rewriter_test.cpp @@ -117,5 +117,90 @@ TEST_F(RewriterTests, SingleCompareEqualRewritePassTrue) { delete rewrote; } + +TEST_F(RewriterTests, SimpleEqualityTree) { + // [=] + // [=] [=] + // [4] [5] [3] [3] + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val5); + auto rb_left_child = new expression::ConstantValueExpression(val3); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + +// (TODO): delete this 
test once more rewriting rules implemented +TEST_F(RewriterTests, SimpleJunctionPreserve) { + // [AND] + // [=] [=] + // [4] [5] [3] [3] + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val5); + auto rb_left_child = new expression::ConstantValueExpression(val3); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 2); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::CONJUNCTION_AND); + + auto left = rewrote->GetChild(0); + auto right = rewrote->GetChild(1); + + EXPECT_TRUE(left != nullptr && right != nullptr); + EXPECT_TRUE(left->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + EXPECT_TRUE(right->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto left_cast = dynamic_cast(left); + auto right_cast = dynamic_cast(right); + EXPECT_TRUE(left_cast->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(right_cast->GetValueType() == type::TypeId::BOOLEAN); + + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(left_cast->GetValue()) == false); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(right_cast->GetValue()) == true); + + delete rewrote; +} + } // namespace test } // namespace peloton From 209c46a689efbc7bc5df4d52cc6afed1ceb88e7e Mon Sep 
17 00:00:00 2001 From: Newton Xie Date: Sun, 7 Apr 2019 14:31:14 -0400 Subject: [PATCH 04/14] Adding AbstractNode class. AbstractNode will provide interface for Operator and eventually AbstractExpressions as well. Note there are a few road blocks before the rest of the rewriter can be changed to cleanly use abstract classes: (1) Similarly abstract OperatorExpressions. (2) We will have to find a good place to hide OpType, which is currently an enum type (cannot be abstracted) and pervades the code base. This may be solved by abstracting at the group level, but will have to look into it. (3) Need to clean up and separate interfaces between AbstractNode, OperatorNode, and Operator classes. --- src/include/optimizer/abstract_node.h | 117 +++++++++++++++++++ src/include/optimizer/input_column_deriver.h | 4 +- src/include/optimizer/operator_node.h | 116 +++--------------- src/include/optimizer/operators.h | 38 +++--- src/optimizer/input_column_deriver.cpp | 4 +- src/optimizer/operator_node.cpp | 3 +- src/optimizer/operators.cpp | 76 ++++++------ 7 files changed, 198 insertions(+), 160 deletions(-) create mode 100644 src/include/optimizer/abstract_node.h diff --git a/src/include/optimizer/abstract_node.h b/src/include/optimizer/abstract_node.h new file mode 100644 index 00000000000..6da97d6d227 --- /dev/null +++ b/src/include/optimizer/abstract_node.h @@ -0,0 +1,117 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// abstract_node.h +// +// Identification: src/include/optimizer/abstract_node.h +// +// Copyright (c) 2015-16, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "optimizer/property_set.h" +#include "util/hash_util.h" + +#include +#include + +namespace peloton { +namespace optimizer { + +enum class OpType { + Undefined = 0, + // Special match operators + Leaf, + // Logical ops + Get, + 
LogicalExternalFileGet, + LogicalQueryDerivedGet, + LogicalProjection, + LogicalFilter, + LogicalMarkJoin, + LogicalDependentJoin, + LogicalSingleJoin, + InnerJoin, + LeftJoin, + RightJoin, + OuterJoin, + SemiJoin, + LogicalAggregateAndGroupBy, + LogicalInsert, + LogicalInsertSelect, + LogicalDelete, + LogicalUpdate, + LogicalLimit, + LogicalDistinct, + LogicalExportExternalFile, + // Separate between logical and physical ops + LogicalPhysicalDelimiter, + // Physical ops + DummyScan, /* Dummy Physical Op for SELECT without FROM*/ + SeqScan, + IndexScan, + ExternalFileScan, + QueryDerivedScan, + OrderBy, + PhysicalLimit, + Distinct, + InnerNLJoin, + LeftNLJoin, + RightNLJoin, + OuterNLJoin, + InnerHashJoin, + LeftHashJoin, + RightHashJoin, + OuterHashJoin, + Insert, + InsertSelect, + Delete, + Update, + Aggregate, + HashGroupBy, + SortGroupBy, + ExportExternalFile, +}; + +//===--------------------------------------------------------------------===// +// Abstract Node +//===--------------------------------------------------------------------===// +//TODO(ncx): dependence on OperatorVisitor +class OperatorVisitor; + +struct AbstractNode { + AbstractNode() {} + + ~AbstractNode() {} + + virtual void Accept(OperatorVisitor *v) const = 0; + + virtual std::string GetName() const = 0; + + // TODO(ncx): problematic dependence on OpType + virtual OpType GetType() const = 0; + + virtual bool IsLogical() const = 0; + + virtual bool IsPhysical() const = 0; + + virtual hash_t Hash() const { + OpType t = GetType(); + return HashUtil::Hash(&t); + } + + virtual bool operator==(const AbstractNode &r) { + return GetType() == r.GetType(); + } + + virtual bool IsDefined() const { return node != nullptr; } + + private: + std::shared_ptr node; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/input_column_deriver.h b/src/include/optimizer/input_column_deriver.h index ef66823bba0..4b218a9b99b 100644 --- a/src/include/optimizer/input_column_deriver.h 
+++ b/src/include/optimizer/input_column_deriver.h @@ -100,8 +100,8 @@ class InputColumnDeriver : public OperatorVisitor { * @brief Provide all tuple value expressions needed in the expression */ void ScanHelper(); - void AggregateHelper(const BaseOperatorNode *); - void JoinHelper(const BaseOperatorNode *op); + void AggregateHelper(const AbstractNode *); + void JoinHelper(const AbstractNode *op); /** * @brief Some operators, for example limit, directly pass down column diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index f870df330eb..da91e320e1f 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -12,6 +12,7 @@ #pragma once +#include "optimizer/abstract_node.h" #include "optimizer/property_set.h" #include "util/hash_util.h" @@ -21,103 +22,24 @@ namespace peloton { namespace optimizer { -enum class OpType { - Undefined = 0, - // Special match operators - Leaf, - // Logical ops - Get, - LogicalExternalFileGet, - LogicalQueryDerivedGet, - LogicalProjection, - LogicalFilter, - LogicalMarkJoin, - LogicalDependentJoin, - LogicalSingleJoin, - InnerJoin, - LeftJoin, - RightJoin, - OuterJoin, - SemiJoin, - LogicalAggregateAndGroupBy, - LogicalInsert, - LogicalInsertSelect, - LogicalDelete, - LogicalUpdate, - LogicalLimit, - LogicalDistinct, - LogicalExportExternalFile, - // Separate between logical and physical ops - LogicalPhysicalDelimiter, - // Physical ops - DummyScan, /* Dummy Physical Op for SELECT without FROM*/ - SeqScan, - IndexScan, - ExternalFileScan, - QueryDerivedScan, - OrderBy, - PhysicalLimit, - Distinct, - InnerNLJoin, - LeftNLJoin, - RightNLJoin, - OuterNLJoin, - InnerHashJoin, - LeftHashJoin, - RightHashJoin, - OuterHashJoin, - Insert, - InsertSelect, - Delete, - Update, - Aggregate, - HashGroupBy, - SortGroupBy, - ExportExternalFile, -}; - //===--------------------------------------------------------------------===// // Operator Node 
//===--------------------------------------------------------------------===// class OperatorVisitor; -struct BaseOperatorNode { - BaseOperatorNode() {} - - virtual ~BaseOperatorNode() {} - - virtual void Accept(OperatorVisitor *v) const = 0; - - virtual std::string GetName() const = 0; - - virtual OpType GetType() const = 0; - - virtual bool IsLogical() const = 0; - - virtual bool IsPhysical() const = 0; - - virtual std::vector RequiredInputProperties() const { - return {}; - } - - virtual hash_t Hash() const { - OpType t = GetType(); - return HashUtil::Hash(&t); - } - - virtual bool operator==(const BaseOperatorNode &r) { - return GetType() == r.GetType(); - } -}; - // Curiously recurring template pattern +// TODO(ncx): this templating would be nice to clean up template -struct OperatorNode : public BaseOperatorNode { +struct OperatorNode : public AbstractNode { + OperatorNode() {} + + virtual ~OperatorNode() {} + void Accept(OperatorVisitor *v) const; - std::string GetName() const { return name_; } + virtual std::string GetName() const { return name_; } - OpType GetType() const { return type_; } + virtual OpType GetType() const { return type_; } bool IsLogical() const; @@ -128,32 +50,26 @@ struct OperatorNode : public BaseOperatorNode { static OpType type_; }; -class Operator { +class Operator : public AbstractNode { public: Operator(); - Operator(BaseOperatorNode *node); + Operator(AbstractNode *node); - // Calls corresponding visitor to node void Accept(OperatorVisitor *v) const; - // Return name of operator std::string GetName() const; - // Return operator type OpType GetType() const; - // Operator contains Logical node bool IsLogical() const; - // Operator contains Physical node bool IsPhysical() const; hash_t Hash() const; bool operator==(const Operator &r); - // Operator contains physical or logical operator node bool IsDefined() const; template @@ -164,8 +80,12 @@ class Operator { return nullptr; } + static std::string name_; + + static OpType type_; + 
private: - std::shared_ptr node; + std::shared_ptr node; }; } // namespace optimizer @@ -174,8 +94,8 @@ class Operator { namespace std { template <> -struct hash { - typedef peloton::optimizer::BaseOperatorNode argument_type; +struct hash { + typedef peloton::optimizer::AbstractNode argument_type; typedef std::size_t result_type; result_type operator()(argument_type const &s) const { return s.Hash(); } }; diff --git a/src/include/optimizer/operators.h b/src/include/optimizer/operators.h index c8a8483d242..80ba32e94ef 100644 --- a/src/include/optimizer/operators.h +++ b/src/include/optimizer/operators.h @@ -55,7 +55,7 @@ class LogicalGet : public OperatorNode { std::shared_ptr table = nullptr, std::string alias = "", bool update = false); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -76,7 +76,7 @@ class LogicalExternalFileGet : public OperatorNode { std::string file_name, char delimiter, char quote, char escape); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -100,7 +100,7 @@ class LogicalQueryDerivedGet : public OperatorNode { std::shared_ptr> alias_to_expr_map); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -120,7 +120,7 @@ class LogicalFilter : public OperatorNode { static Operator make(std::vector &filter); std::vector predicates; - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; }; @@ -144,7 +144,7 @@ class LogicalDependentJoin : public OperatorNode { static Operator make(std::vector &conditions); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -160,7 +160,7 @@ class LogicalMarkJoin : public OperatorNode { static 
Operator make(std::vector &conditions); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -176,7 +176,7 @@ class LogicalSingleJoin : public OperatorNode { static Operator make(std::vector &conditions); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -192,7 +192,7 @@ class LogicalInnerJoin : public OperatorNode { static Operator make(std::vector &conditions); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -254,7 +254,7 @@ class LogicalAggregateAndGroupBy std::vector> &columns, std::vector &having); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; std::vector> columns; @@ -346,7 +346,7 @@ class LogicalExportExternalFile static Operator make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -376,7 +376,7 @@ class PhysicalSeqScan : public OperatorNode { std::vector predicates, bool update); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -401,7 +401,7 @@ class PhysicalIndexScan : public OperatorNode { std::vector expr_type_list, std::vector value_list); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -430,7 +430,7 @@ class ExternalFileScan : public OperatorNode { std::string file_name, char delimiter, char quote, char escape); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; 
@@ -454,7 +454,7 @@ class QueryDerivedScan : public OperatorNode { std::shared_ptr> alias_to_expr_map); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -503,7 +503,7 @@ class PhysicalInnerNLJoin : public OperatorNode { std::vector> &left_keys, std::vector> &right_keys); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -553,7 +553,7 @@ class PhysicalInnerHashJoin : public OperatorNode { std::vector> &left_keys, std::vector> &right_keys); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -650,7 +650,7 @@ class PhysicalExportExternalFile static Operator make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -670,7 +670,7 @@ class PhysicalHashGroupBy : public OperatorNode { std::vector> columns, std::vector having); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; std::vector> columns; @@ -686,7 +686,7 @@ class PhysicalSortGroupBy : public OperatorNode { std::vector> columns, std::vector having); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; // TODO(boweic): use raw ptr std::vector> columns; diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index fdffb7e79a6..e764a0d902a 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -200,7 +200,7 @@ void InputColumnDeriver::ScanHelper() { output_cols, {}}; } -void InputColumnDeriver::AggregateHelper(const BaseOperatorNode 
*op) { +void InputColumnDeriver::AggregateHelper(const AbstractNode *op) { ExprSet input_cols_set; ExprMap output_cols_map; oid_t output_col_idx = 0; @@ -269,7 +269,7 @@ void InputColumnDeriver::AggregateHelper(const BaseOperatorNode *op) { output_cols, {input_cols}}; } -void InputColumnDeriver::JoinHelper(const BaseOperatorNode *op) { +void InputColumnDeriver::JoinHelper(const AbstractNode *op) { const vector *join_conds = nullptr; const vector> *left_keys = nullptr; const vector> *right_keys = diff --git a/src/optimizer/operator_node.cpp b/src/optimizer/operator_node.cpp index e262792e774..0785020a41d 100644 --- a/src/optimizer/operator_node.cpp +++ b/src/optimizer/operator_node.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "optimizer/abstract_node.h" #include "optimizer/operator_node.h" namespace peloton { @@ -20,7 +21,7 @@ namespace optimizer { //===--------------------------------------------------------------------===// Operator::Operator() : node(nullptr) {} -Operator::Operator(BaseOperatorNode *node) : node(node) {} +Operator::Operator(AbstractNode *node) : node(node) {} void Operator::Accept(OperatorVisitor *v) const { node->Accept(v); } diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index 52cf83f9a8c..e6f8c0294d2 100644 --- a/src/optimizer/operators.cpp +++ b/src/optimizer/operators.cpp @@ -45,14 +45,14 @@ Operator LogicalGet::make(oid_t get_id, } hash_t LogicalGet::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); for (auto &pred : predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalGet::operator==(const BaseOperatorNode &r) { +bool LogicalGet::operator==(const AbstractNode &r) { if (r.GetType() != OpType::Get) return false; const LogicalGet &node = *static_cast(&r); if (predicates.size() != 
node.predicates.size()) return false; @@ -80,7 +80,7 @@ Operator LogicalExternalFileGet::make(oid_t get_id, ExternalFileFormat format, return Operator(get); } -bool LogicalExternalFileGet::operator==(const BaseOperatorNode &node) { +bool LogicalExternalFileGet::operator==(const AbstractNode &node) { if (node.GetType() != OpType::LogicalExternalFileGet) return false; const auto &get = *static_cast(&node); return (get_id == get.get_id && format == get.format && @@ -89,7 +89,7 @@ bool LogicalExternalFileGet::operator==(const BaseOperatorNode &node) { } hash_t LogicalExternalFileGet::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&format)); hash = HashUtil::CombineHashes( @@ -116,7 +116,7 @@ Operator LogicalQueryDerivedGet::make( return Operator(get); } -bool LogicalQueryDerivedGet::operator==(const BaseOperatorNode &node) { +bool LogicalQueryDerivedGet::operator==(const AbstractNode &node) { if (node.GetType() != OpType::LogicalQueryDerivedGet) return false; const LogicalQueryDerivedGet &r = *static_cast(&node); @@ -124,7 +124,7 @@ bool LogicalQueryDerivedGet::operator==(const BaseOperatorNode &node) { } hash_t LogicalQueryDerivedGet::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); return hash; } @@ -139,13 +139,13 @@ Operator LogicalFilter::make(std::vector &filter) { } hash_t LogicalFilter::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalFilter::operator==(const BaseOperatorNode &r) { +bool LogicalFilter::operator==(const AbstractNode &r) { if (r.GetType() != OpType::LogicalFilter) return false; const LogicalFilter &node = 
*static_cast(&r); if (predicates.size() != node.predicates.size()) return false; @@ -182,13 +182,13 @@ Operator LogicalDependentJoin::make( } hash_t LogicalDependentJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalDependentJoin::operator==(const BaseOperatorNode &r) { +bool LogicalDependentJoin::operator==(const AbstractNode &r) { if (r.GetType() != OpType::LogicalDependentJoin) return false; const LogicalDependentJoin &node = *static_cast(&r); @@ -217,13 +217,13 @@ Operator LogicalMarkJoin::make(std::vector &conditions) { } hash_t LogicalMarkJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalMarkJoin::operator==(const BaseOperatorNode &r) { +bool LogicalMarkJoin::operator==(const AbstractNode &r) { if (r.GetType() != OpType::LogicalMarkJoin) return false; const LogicalMarkJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; @@ -251,13 +251,13 @@ Operator LogicalSingleJoin::make(std::vector &conditions) { } hash_t LogicalSingleJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalSingleJoin::operator==(const BaseOperatorNode &r) { +bool LogicalSingleJoin::operator==(const AbstractNode &r) { if (r.GetType() != OpType::LogicalSingleJoin) return false; const LogicalSingleJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; @@ -285,13 +285,13 @@ Operator LogicalInnerJoin::make(std::vector &conditions) { } hash_t LogicalInnerJoin::Hash() const { - hash_t hash = 
BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalInnerJoin::operator==(const BaseOperatorNode &r) { +bool LogicalInnerJoin::operator==(const AbstractNode &r) { if (r.GetType() != OpType::InnerJoin) return false; const LogicalInnerJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; @@ -368,7 +368,7 @@ Operator LogicalAggregateAndGroupBy::make( return Operator(group_by); } -bool LogicalAggregateAndGroupBy::operator==(const BaseOperatorNode &node) { +bool LogicalAggregateAndGroupBy::operator==(const AbstractNode &node) { if (node.GetType() != OpType::LogicalAggregateAndGroupBy) return false; const LogicalAggregateAndGroupBy &r = *static_cast(&node); @@ -381,7 +381,7 @@ bool LogicalAggregateAndGroupBy::operator==(const BaseOperatorNode &node) { } hash_t LogicalAggregateAndGroupBy::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : having) hash = HashUtil::SumHashes(hash, pred.expr->Hash()); for (auto expr : columns) hash = HashUtil::SumHashes(hash, expr->Hash()); return hash; @@ -470,7 +470,7 @@ Operator LogicalExportExternalFile::make(ExternalFileFormat format, return Operator(export_op); } -bool LogicalExportExternalFile::operator==(const BaseOperatorNode &node) { +bool LogicalExportExternalFile::operator==(const AbstractNode &node) { if (node.GetType() != OpType::LogicalExportExternalFile) return false; const auto &export_op = *static_cast(&node); @@ -480,7 +480,7 @@ bool LogicalExportExternalFile::operator==(const BaseOperatorNode &node) { } hash_t LogicalExportExternalFile::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&format)); hash = HashUtil::CombineHashes( hash, HashUtil::HashBytes(file_name.data(), 
file_name.length())); @@ -516,7 +516,7 @@ Operator PhysicalSeqScan::make( return Operator(scan); } -bool PhysicalSeqScan::operator==(const BaseOperatorNode &r) { +bool PhysicalSeqScan::operator==(const AbstractNode &r) { if (r.GetType() != OpType::SeqScan) return false; const PhysicalSeqScan &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; @@ -528,7 +528,7 @@ bool PhysicalSeqScan::operator==(const BaseOperatorNode &r) { } hash_t PhysicalSeqScan::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); for (auto &pred : predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); @@ -559,7 +559,7 @@ Operator PhysicalIndexScan::make( return Operator(scan); } -bool PhysicalIndexScan::operator==(const BaseOperatorNode &r) { +bool PhysicalIndexScan::operator==(const AbstractNode &r) { if (r.GetType() != OpType::IndexScan) return false; const PhysicalIndexScan &node = *static_cast(&r); // TODO: Should also check value list @@ -577,7 +577,7 @@ bool PhysicalIndexScan::operator==(const BaseOperatorNode &r) { } hash_t PhysicalIndexScan::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&index_id)); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); for (auto &pred : predicates) @@ -601,7 +601,7 @@ Operator ExternalFileScan::make(oid_t get_id, ExternalFileFormat format, return Operator(get); } -bool ExternalFileScan::operator==(const BaseOperatorNode &node) { +bool ExternalFileScan::operator==(const AbstractNode &node) { if (node.GetType() != OpType::QueryDerivedScan) return false; const auto &get = *static_cast(&node); return (get_id == get.get_id && format == get.format && @@ -610,7 +610,7 @@ bool ExternalFileScan::operator==(const BaseOperatorNode &node) { } hash_t ExternalFileScan::Hash() const { - hash_t hash = 
BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&format)); hash = HashUtil::CombineHashes( @@ -637,14 +637,14 @@ Operator QueryDerivedScan::make( return Operator(get); } -bool QueryDerivedScan::operator==(const BaseOperatorNode &node) { +bool QueryDerivedScan::operator==(const AbstractNode &node) { if (node.GetType() != OpType::QueryDerivedScan) return false; const QueryDerivedScan &r = *static_cast(&node); return get_id == r.get_id; } hash_t QueryDerivedScan::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); return hash; } @@ -689,7 +689,7 @@ Operator PhysicalInnerNLJoin::make( } hash_t PhysicalInnerNLJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &expr : left_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); for (auto &expr : right_keys) @@ -699,7 +699,7 @@ hash_t PhysicalInnerNLJoin::Hash() const { return hash; } -bool PhysicalInnerNLJoin::operator==(const BaseOperatorNode &r) { +bool PhysicalInnerNLJoin::operator==(const AbstractNode &r) { if (r.GetType() != OpType::InnerNLJoin) return false; const PhysicalInnerNLJoin &node = *static_cast(&r); @@ -766,7 +766,7 @@ Operator PhysicalInnerHashJoin::make( } hash_t PhysicalInnerHashJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &expr : left_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); for (auto &expr : right_keys) @@ -776,7 +776,7 @@ hash_t PhysicalInnerHashJoin::Hash() const { return hash; } -bool PhysicalInnerHashJoin::operator==(const BaseOperatorNode &r) { +bool PhysicalInnerHashJoin::operator==(const AbstractNode &r) { if (r.GetType() != OpType::InnerHashJoin) return false; const PhysicalInnerHashJoin &node = *static_cast(&r); @@ 
-891,7 +891,7 @@ Operator PhysicalExportExternalFile::make(ExternalFileFormat format, return Operator(export_op); } -bool PhysicalExportExternalFile::operator==(const BaseOperatorNode &node) { +bool PhysicalExportExternalFile::operator==(const AbstractNode &node) { if (node.GetType() != OpType::ExportExternalFile) return false; const auto &export_op = *static_cast(&node); @@ -901,7 +901,7 @@ bool PhysicalExportExternalFile::operator==(const BaseOperatorNode &node) { } hash_t PhysicalExportExternalFile::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&format)); hash = HashUtil::CombineHashes( hash, HashUtil::HashBytes(file_name.data(), file_name.length())); @@ -923,7 +923,7 @@ Operator PhysicalHashGroupBy::make( return Operator(agg); } -bool PhysicalHashGroupBy::operator==(const BaseOperatorNode &node) { +bool PhysicalHashGroupBy::operator==(const AbstractNode &node) { if (node.GetType() != OpType::HashGroupBy) return false; const PhysicalHashGroupBy &r = *static_cast(&node); @@ -936,7 +936,7 @@ bool PhysicalHashGroupBy::operator==(const BaseOperatorNode &node) { } hash_t PhysicalHashGroupBy::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : having) hash = HashUtil::SumHashes(hash, pred.expr->Hash()); for (auto expr : columns) hash = HashUtil::SumHashes(hash, expr->Hash()); return hash; @@ -954,7 +954,7 @@ Operator PhysicalSortGroupBy::make( return Operator(agg); } -bool PhysicalSortGroupBy::operator==(const BaseOperatorNode &node) { +bool PhysicalSortGroupBy::operator==(const AbstractNode &node) { if (node.GetType() != OpType::SortGroupBy) return false; const PhysicalSortGroupBy &r = *static_cast(&node); @@ -967,7 +967,7 @@ bool PhysicalSortGroupBy::operator==(const BaseOperatorNode &node) { } hash_t PhysicalSortGroupBy::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = 
AbstractNode::Hash(); for (auto &pred : having) hash = HashUtil::SumHashes(hash, pred.expr->Hash()); for (auto expr : columns) hash = HashUtil::SumHashes(hash, expr->Hash()); return hash; From 17de3b9bd7898841ec0ede915243410ebfa90b17 Mon Sep 17 00:00:00 2001 From: Newton Xie Date: Fri, 26 Apr 2019 05:03:20 -0400 Subject: [PATCH 05/14] Using AbstractNode throughout optimizer. Abstract nodes were implemented in 209c46a. This is essentially just refactoring and plugging in abstract nodes throughout the optimizer. The abstract interface exposes OpType and ExpressionType for now, which ideally will be fixed later. Work remaining for abstracting OperatorExpression. --- src/include/optimizer/abstract_node.h | 23 ++- .../optimizer/cost_model/default_cost_model.h | 2 +- .../cost_model/postgres_cost_model.h | 4 +- .../optimizer/cost_model/trivial_cost_model.h | 4 +- src/include/optimizer/group_expression.h | 8 +- src/include/optimizer/operator_node.h | 32 ++-- src/include/optimizer/optimizer_metadata.h | 4 +- src/optimizer/binding.cpp | 7 +- src/optimizer/child_property_deriver.cpp | 2 +- src/optimizer/group.cpp | 10 +- src/optimizer/group_expression.cpp | 12 +- src/optimizer/input_column_deriver.cpp | 10 +- src/optimizer/memo.cpp | 10 +- src/optimizer/operator_node.cpp | 15 +- src/optimizer/operators.cpp | 139 +++++++++--------- src/optimizer/optimizer.cpp | 7 +- src/optimizer/optimizer_task.cpp | 10 +- src/optimizer/property_enforcer.cpp | 6 +- src/optimizer/rule.cpp | 4 +- src/optimizer/rule_impls.cpp | 10 +- src/optimizer/stats/child_stats_deriver.cpp | 2 +- src/optimizer/stats/stats_calculator.cpp | 2 +- test/optimizer/optimizer_test.cpp | 34 +++-- 23 files changed, 193 insertions(+), 164 deletions(-) diff --git a/src/include/optimizer/abstract_node.h b/src/include/optimizer/abstract_node.h index 6da97d6d227..4a076751abc 100644 --- a/src/include/optimizer/abstract_node.h +++ b/src/include/optimizer/abstract_node.h @@ -83,7 +83,7 @@ enum class OpType { class 
OperatorVisitor; struct AbstractNode { - AbstractNode() {} + AbstractNode(AbstractNode *node) : node(node) {} ~AbstractNode() {} @@ -91,25 +91,36 @@ struct AbstractNode { virtual std::string GetName() const = 0; - // TODO(ncx): problematic dependence on OpType - virtual OpType GetType() const = 0; + // TODO(ncx): dependence on OpType and ExpressionType (ideally abstracted away) + virtual OpType GetOpType() const = 0; + + virtual ExpressionType GetExpType() const = 0; virtual bool IsLogical() const = 0; virtual bool IsPhysical() const = 0; virtual hash_t Hash() const { - OpType t = GetType(); + // TODO(ncx): hash should work for ExpressionType nodes + OpType t = GetOpType(); return HashUtil::Hash(&t); } virtual bool operator==(const AbstractNode &r) { - return GetType() == r.GetType(); + return GetOpType() == r.GetOpType() && GetExpType() == r.GetExpType(); } virtual bool IsDefined() const { return node != nullptr; } - private: + template + const T *As() const { + if (node && typeid(*node) == typeid(T)) { + return (const T *)node.get(); + } + return nullptr; + } + + protected: std::shared_ptr node; }; diff --git a/src/include/optimizer/cost_model/default_cost_model.h b/src/include/optimizer/cost_model/default_cost_model.h index a92cb091db7..d5b9be3c82b 100644 --- a/src/include/optimizer/cost_model/default_cost_model.h +++ b/src/include/optimizer/cost_model/default_cost_model.h @@ -34,7 +34,7 @@ class DefaultCostModel : public AbstractCostModel { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - gexpr_->Op().Accept(this); + gexpr_->Op()->Accept(this); return output_cost_; } diff --git a/src/include/optimizer/cost_model/postgres_cost_model.h b/src/include/optimizer/cost_model/postgres_cost_model.h index 2632a247a39..87e9858cd06 100644 --- a/src/include/optimizer/cost_model/postgres_cost_model.h +++ b/src/include/optimizer/cost_model/postgres_cost_model.h @@ -39,7 +39,7 @@ class PostgresCostModel : public AbstractCostModel { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - 
gexpr_->Op().Accept(this); + gexpr_->Op()->Accept(this); return output_cost_; }; @@ -279,4 +279,4 @@ class PostgresCostModel : public AbstractCostModel { }; } // namespace optimizer -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/optimizer/cost_model/trivial_cost_model.h b/src/include/optimizer/cost_model/trivial_cost_model.h index 2c5994ee728..5e3d7fd195f 100644 --- a/src/include/optimizer/cost_model/trivial_cost_model.h +++ b/src/include/optimizer/cost_model/trivial_cost_model.h @@ -41,7 +41,7 @@ class TrivialCostModel : public AbstractCostModel { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - gexpr_->Op().Accept(this); + gexpr_->Op()->Accept(this); return output_cost_; }; @@ -116,4 +116,4 @@ class TrivialCostModel : public AbstractCostModel { }; } // namespace optimizer -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/optimizer/group_expression.h b/src/include/optimizer/group_expression.h index 303ebaf036e..6a25871cbb1 100644 --- a/src/include/optimizer/group_expression.h +++ b/src/include/optimizer/group_expression.h @@ -12,7 +12,7 @@ #pragma once -#include "optimizer/operator_node.h" +#include "optimizer/abstract_node.h" #include "optimizer/stats/stats.h" #include "optimizer/util.h" #include "optimizer/property_set.h" @@ -34,7 +34,7 @@ using GroupID = int32_t; //===--------------------------------------------------------------------===// class GroupExpression { public: - GroupExpression(Operator op, std::vector child_groups); + GroupExpression(std::shared_ptr node, std::vector child_groups); GroupID GetGroupID() const; @@ -46,7 +46,7 @@ class GroupExpression { GroupID GetChildGroupId(int child_idx) const; - Operator Op() const; + std::shared_ptr Op() const; double GetCost(std::shared_ptr& requirements) const; @@ -75,7 +75,7 @@ class GroupExpression { private: GroupID group_id; - Operator op; + std::shared_ptr node; std::vector child_groups; 
std::bitset(RuleType::NUM_RULES)> rule_mask_; bool stats_derived_; diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index da91e320e1f..3a26c6daaa4 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -28,18 +28,19 @@ namespace optimizer { class OperatorVisitor; // Curiously recurring template pattern -// TODO(ncx): this templating would be nice to clean up template struct OperatorNode : public AbstractNode { - OperatorNode() {} + OperatorNode() : AbstractNode(nullptr) {} virtual ~OperatorNode() {} void Accept(OperatorVisitor *v) const; - virtual std::string GetName() const { return name_; } + std::string GetName() const { return name_; } - virtual OpType GetType() const { return type_; } + OpType GetOpType() const { return op_type_; } + + ExpressionType GetExpType() const { return exp_type_; } bool IsLogical() const; @@ -47,7 +48,9 @@ struct OperatorNode : public AbstractNode { static std::string name_; - static OpType type_; + static OpType op_type_; + + static ExpressionType exp_type_; }; class Operator : public AbstractNode { @@ -60,7 +63,9 @@ class Operator : public AbstractNode { std::string GetName() const; - OpType GetType() const; + OpType GetOpType() const; + + ExpressionType GetExpType() const; bool IsLogical() const; @@ -71,21 +76,6 @@ class Operator : public AbstractNode { bool operator==(const Operator &r); bool IsDefined() const; - - template - const T *As() const { - if (node && typeid(*node) == typeid(T)) { - return (const T *)node.get(); - } - return nullptr; - } - - static std::string name_; - - static OpType type_; - - private: - std::shared_ptr node; }; } // namespace optimizer diff --git a/src/include/optimizer/optimizer_metadata.h b/src/include/optimizer/optimizer_metadata.h index 3f33e3ee8b1..85782dd09bf 100644 --- a/src/include/optimizer/optimizer_metadata.h +++ b/src/include/optimizer/optimizer_metadata.h @@ -19,6 +19,8 @@ #include "optimizer/rule.h" 
#include "settings/settings_manager.h" +#include + namespace peloton { namespace catalog { class Catalog; @@ -58,7 +60,7 @@ class OptimizerMetadata { memo.InsertExpression(gexpr, false); child_groups.push_back(gexpr->GetGroupID()); } - return std::make_shared(expr->Op(), + return std::make_shared(std::make_shared(expr->Op()), std::move(child_groups)); } diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index 9651ce8102c..ab6d7a03bd8 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -85,8 +85,9 @@ GroupExprBindingIterator::GroupExprBindingIterator( pattern_(pattern), first_(true), has_next_(false), - current_binding_(std::make_shared(gexpr->Op())) { - if (gexpr->Op().GetType() != pattern->Type()) { + // TODO(ncx): fix once OperatorExpression is abstracted + current_binding_(std::make_shared(*(Operator *)gexpr->Op().get())) { + if (gexpr->Op()->GetOpType() != pattern->Type()) { return; } @@ -100,7 +101,7 @@ GroupExprBindingIterator::GroupExprBindingIterator( LOG_TRACE( "Attempting to bind on group %d with expression of type %s, children " "size %lu", - gexpr->GetGroupID(), gexpr->Op().GetName().c_str(), child_groups.size()); + gexpr->GetGroupID(), gexpr->Op()->GetName().c_str(), child_groups.size()); // Find all bindings for children children_bindings_.resize(child_groups.size(), {}); diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index b432067fae1..9b8adebbb68 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -38,7 +38,7 @@ ChildPropertyDeriver::GetProperties(GroupExpression *gexpr, output_.clear(); memo_ = memo; gexpr_ = gexpr; - gexpr->Op().Accept(this); + gexpr->Op()->Accept(this); return move(output_); } diff --git a/src/optimizer/group.cpp b/src/optimizer/group.cpp index 673a7a1b8bd..33bc32484a0 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -31,7 +31,7 @@ void 
Group::AddExpression(std::shared_ptr expr, expr->SetGroupID(id_); if (enforced) enforced_exprs_.push_back(expr); - else if (expr->Op().IsPhysical()) + else if (expr->Op()->IsPhysical()) physical_expressions_.push_back(expr); else logical_expressions_.push_back(expr); @@ -40,7 +40,7 @@ void Group::AddExpression(std::shared_ptr expr, bool Group::SetExpressionCost(GroupExpression *expr, double cost, std::shared_ptr &properties) { LOG_TRACE("Adding expression cost on group %d with op %s, req %s", - expr->GetGroupID(), expr->Op().GetName().c_str(), + expr->GetGroupID(), expr->Op()->GetName().c_str(), properties->ToString().c_str()); auto it = lowest_cost_expressions_.find(properties); if (it == lowest_cost_expressions_.end() || std::get<0>(it->second) > cost) { @@ -86,7 +86,7 @@ const std::string Group::GetInfo(int num_indent) const { for (auto expr : logical_expressions_) { os << StringUtil::Indent(num_indent + 4) - << expr->Op().GetName() << std::endl; + << expr->Op()->GetName() << std::endl; const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); if (ChildGroupIDs.size() > 0) { os << StringUtil::Indent(num_indent + 6) @@ -102,7 +102,7 @@ const std::string Group::GetInfo(int num_indent) const { << "physical_expressions_: \n"; for (auto expr : physical_expressions_) { os << StringUtil::Indent(num_indent + 4) - << expr->Op().GetName() << std::endl; + << expr->Op()->GetName() << std::endl; const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); if (ChildGroupIDs.size() > 0) { os << StringUtil::Indent(num_indent + 6) @@ -119,7 +119,7 @@ const std::string Group::GetInfo(int num_indent) const { << "enforced_exprs_: \n"; for (auto expr : enforced_exprs_) { os << StringUtil::Indent(num_indent + 4) - << expr->Op().GetName() << std::endl; + << expr->Op()->GetName() << std::endl; const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); if (ChildGroupIDs.size() > 0) { os << StringUtil::Indent(num_indent + 6) diff --git a/src/optimizer/group_expression.cpp 
b/src/optimizer/group_expression.cpp index 498c949b583..5db0fb32a82 100644 --- a/src/optimizer/group_expression.cpp +++ b/src/optimizer/group_expression.cpp @@ -21,9 +21,9 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Group Expression //===--------------------------------------------------------------------===// -GroupExpression::GroupExpression(Operator op, std::vector child_groups) +GroupExpression::GroupExpression(std::shared_ptr node, std::vector child_groups) : group_id(UNDEFINED_GROUP), - op(op), + node(node), child_groups(child_groups), stats_derived_(false) {} @@ -43,7 +43,9 @@ GroupID GroupExpression::GetChildGroupId(int child_idx) const { return child_groups[child_idx]; } -Operator GroupExpression::Op() const { return op; } +std::shared_ptr GroupExpression::Op() const { + return std::shared_ptr(node); +} double GroupExpression::GetCost( std::shared_ptr &requirements) const { @@ -74,7 +76,7 @@ void GroupExpression::SetLocalHashTable( } hash_t GroupExpression::Hash() const { - size_t hash = op.Hash(); + size_t hash = node->Hash(); for (size_t i = 0; i < child_groups.size(); ++i) { hash = HashUtil::CombineHashes(hash, @@ -85,7 +87,7 @@ hash_t GroupExpression::Hash() const { } bool GroupExpression::operator==(const GroupExpression &r) { - return (op == r.Op()) && (child_groups == r.child_groups); + return (*node == *r.Op()) && (child_groups == r.child_groups); } void GroupExpression::SetRuleExplored(Rule *rule) { diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index e764a0d902a..7104485454e 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -43,7 +43,7 @@ InputColumnDeriver::DeriveInputColumns( gexpr_ = gexpr; required_cols_ = move(required_cols); memo_ = memo; - gexpr->Op().Accept(this); + gexpr->Op()->Accept(this); return move(output_input_cols_); } @@ -232,11 +232,11 @@ void 
InputColumnDeriver::AggregateHelper(const AbstractNode *op) { // TODO(boweic): do not use shared_ptr vector> groupby_cols; vector having_exprs; - if (op->GetType() == OpType::HashGroupBy) { + if (op->GetOpType() == OpType::HashGroupBy) { auto groupby = reinterpret_cast(op); groupby_cols = groupby->columns; having_exprs = groupby->having; - } else if (op->GetType() == OpType::SortGroupBy) { + } else if (op->GetOpType() == OpType::SortGroupBy) { auto groupby = reinterpret_cast(op); groupby_cols = groupby->columns; having_exprs = groupby->having; @@ -274,12 +274,12 @@ void InputColumnDeriver::JoinHelper(const AbstractNode *op) { const vector> *left_keys = nullptr; const vector> *right_keys = nullptr; - if (op->GetType() == OpType::InnerHashJoin) { + if (op->GetOpType() == OpType::InnerHashJoin) { auto join_op = reinterpret_cast(op); join_conds = &(join_op->join_predicates); left_keys = &(join_op->left_keys); right_keys = &(join_op->right_keys); - } else if (op->GetType() == OpType::InnerNLJoin) { + } else if (op->GetOpType() == OpType::InnerNLJoin) { auto join_op = reinterpret_cast(op); join_conds = &(join_op->join_predicates); left_keys = &(join_op->left_keys); diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index ca68a52c1d0..69e1e8a54f4 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -31,8 +31,8 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, GroupID target_group, bool enforced) { // If leaf, then just return - if (gexpr->Op().GetType() == OpType::Leaf) { - const LeafOperator *leaf = gexpr->Op().As(); + if (gexpr->Op()->GetOpType() == OpType::Leaf) { + const LeafOperator *leaf = gexpr->Op()->As(); PELOTON_ASSERT(target_group == UNDEFINED_GROUP || target_group == leaf->origin_group); gexpr->SetGroupID(leaf->origin_group); @@ -91,14 +91,14 @@ GroupID Memo::AddNewGroup(std::shared_ptr gexpr) { GroupID new_group_id = groups_.size(); // Find out the table 
alias that this group represents std::unordered_set table_aliases; - auto op_type = gexpr->Op().GetType(); + auto op_type = gexpr->Op()->GetOpType(); if (op_type == OpType::Get) { // For base group, the table alias can get directly from logical get - const LogicalGet *logical_get = gexpr->Op().As(); + const LogicalGet *logical_get = gexpr->Op()->As(); table_aliases.insert(logical_get->table_alias); } else if (op_type == OpType::LogicalQueryDerivedGet) { const LogicalQueryDerivedGet *query_get = - gexpr->Op().As(); + gexpr->Op()->As(); table_aliases.insert(query_get->table_alias); } else { // For other groups, need to aggregate the table alias from children diff --git a/src/optimizer/operator_node.cpp b/src/optimizer/operator_node.cpp index 0785020a41d..38e03310b94 100644 --- a/src/optimizer/operator_node.cpp +++ b/src/optimizer/operator_node.cpp @@ -19,9 +19,9 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Operator //===--------------------------------------------------------------------===// -Operator::Operator() : node(nullptr) {} +Operator::Operator() : AbstractNode(nullptr) {} -Operator::Operator(AbstractNode *node) : node(node) {} +Operator::Operator(AbstractNode *node) : AbstractNode(node) {} void Operator::Accept(OperatorVisitor *v) const { node->Accept(v); } @@ -32,13 +32,20 @@ std::string Operator::GetName() const { return "Undefined"; } -OpType Operator::GetType() const { +OpType Operator::GetOpType() const { if (IsDefined()) { - return node->GetType(); + return node->GetOpType(); } return OpType::Undefined; } +ExpressionType Operator::GetExpType() const { + if (IsDefined()) { + return node->GetExpType(); + } + return ExpressionType::INVALID; +} + bool Operator::IsLogical() const { if (IsDefined()) { return node->IsLogical(); diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index e6f8c0294d2..f75b1159a42 100644 --- a/src/optimizer/operators.cpp +++ 
b/src/optimizer/operators.cpp @@ -53,7 +53,7 @@ hash_t LogicalGet::Hash() const { } bool LogicalGet::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::Get) return false; + if (r.GetOpType() != OpType::Get) return false; const LogicalGet &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; for (size_t i = 0; i < predicates.size(); i++) { @@ -81,7 +81,7 @@ Operator LogicalExternalFileGet::make(oid_t get_id, ExternalFileFormat format, } bool LogicalExternalFileGet::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::LogicalExternalFileGet) return false; + if (node.GetOpType() != OpType::LogicalExternalFileGet) return false; const auto &get = *static_cast(&node); return (get_id == get.get_id && format == get.format && file_name == get.file_name && delimiter == get.delimiter && @@ -117,7 +117,7 @@ Operator LogicalQueryDerivedGet::make( } bool LogicalQueryDerivedGet::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::LogicalQueryDerivedGet) return false; + if (node.GetOpType() != OpType::LogicalQueryDerivedGet) return false; const LogicalQueryDerivedGet &r = *static_cast(&node); return get_id == r.get_id; @@ -146,7 +146,7 @@ hash_t LogicalFilter::Hash() const { } bool LogicalFilter::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::LogicalFilter) return false; + if (r.GetOpType() != OpType::LogicalFilter) return false; const LogicalFilter &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; for (size_t i = 0; i < predicates.size(); i++) { @@ -189,7 +189,7 @@ hash_t LogicalDependentJoin::Hash() const { } bool LogicalDependentJoin::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::LogicalDependentJoin) return false; + if (r.GetOpType() != OpType::LogicalDependentJoin) return false; const LogicalDependentJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; @@ -224,7 
+224,7 @@ hash_t LogicalMarkJoin::Hash() const { } bool LogicalMarkJoin::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::LogicalMarkJoin) return false; + if (r.GetOpType() != OpType::LogicalMarkJoin) return false; const LogicalMarkJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; for (size_t i = 0; i < join_predicates.size(); i++) { @@ -258,7 +258,7 @@ hash_t LogicalSingleJoin::Hash() const { } bool LogicalSingleJoin::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::LogicalSingleJoin) return false; + if (r.GetOpType() != OpType::LogicalSingleJoin) return false; const LogicalSingleJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; for (size_t i = 0; i < join_predicates.size(); i++) { @@ -292,7 +292,7 @@ hash_t LogicalInnerJoin::Hash() const { } bool LogicalInnerJoin::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::InnerJoin) return false; + if (r.GetOpType() != OpType::InnerJoin) return false; const LogicalInnerJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; for (size_t i = 0; i < join_predicates.size(); i++) { @@ -369,7 +369,7 @@ Operator LogicalAggregateAndGroupBy::make( } bool LogicalAggregateAndGroupBy::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::LogicalAggregateAndGroupBy) return false; + if (node.GetOpType() != OpType::LogicalAggregateAndGroupBy) return false; const LogicalAggregateAndGroupBy &r = *static_cast(&node); if (having.size() != r.having.size() || columns.size() != r.columns.size()) @@ -471,7 +471,7 @@ Operator LogicalExportExternalFile::make(ExternalFileFormat format, } bool LogicalExportExternalFile::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::LogicalExportExternalFile) return false; + if (node.GetOpType() != OpType::LogicalExportExternalFile) return false; const auto &export_op = 
*static_cast(&node); return (format == export_op.format && file_name == export_op.file_name && @@ -517,7 +517,7 @@ Operator PhysicalSeqScan::make( } bool PhysicalSeqScan::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::SeqScan) return false; + if (r.GetOpType() != OpType::SeqScan) return false; const PhysicalSeqScan &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; for (size_t i = 0; i < predicates.size(); i++) { @@ -560,7 +560,7 @@ Operator PhysicalIndexScan::make( } bool PhysicalIndexScan::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::IndexScan) return false; + if (r.GetOpType() != OpType::IndexScan) return false; const PhysicalIndexScan &node = *static_cast(&r); // TODO: Should also check value list if (index_id != node.index_id || @@ -602,7 +602,7 @@ Operator ExternalFileScan::make(oid_t get_id, ExternalFileFormat format, } bool ExternalFileScan::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::QueryDerivedScan) return false; + if (node.GetOpType() != OpType::QueryDerivedScan) return false; const auto &get = *static_cast(&node); return (get_id == get.get_id && format == get.format && file_name == get.file_name && delimiter == get.delimiter && @@ -638,7 +638,7 @@ Operator QueryDerivedScan::make( } bool QueryDerivedScan::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::QueryDerivedScan) return false; + if (node.GetOpType() != OpType::QueryDerivedScan) return false; const QueryDerivedScan &r = *static_cast(&node); return get_id == r.get_id; } @@ -700,7 +700,7 @@ hash_t PhysicalInnerNLJoin::Hash() const { } bool PhysicalInnerNLJoin::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::InnerNLJoin) return false; + if (r.GetOpType() != OpType::InnerNLJoin) return false; const PhysicalInnerNLJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size() || @@ -777,7 +777,7 @@ hash_t PhysicalInnerHashJoin::Hash() 
const { } bool PhysicalInnerHashJoin::operator==(const AbstractNode &r) { - if (r.GetType() != OpType::InnerHashJoin) return false; + if (r.GetOpType() != OpType::InnerHashJoin) return false; const PhysicalInnerHashJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size() || @@ -892,7 +892,7 @@ Operator PhysicalExportExternalFile::make(ExternalFileFormat format, } bool PhysicalExportExternalFile::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::ExportExternalFile) return false; + if (node.GetOpType() != OpType::ExportExternalFile) return false; const auto &export_op = *static_cast(&node); return (format == export_op.format && file_name == export_op.file_name && @@ -924,7 +924,7 @@ Operator PhysicalHashGroupBy::make( } bool PhysicalHashGroupBy::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::HashGroupBy) return false; + if (node.GetOpType() != OpType::HashGroupBy) return false; const PhysicalHashGroupBy &r = *static_cast(&node); if (having.size() != r.having.size() || columns.size() != r.columns.size()) @@ -955,7 +955,7 @@ Operator PhysicalSortGroupBy::make( } bool PhysicalSortGroupBy::operator==(const AbstractNode &node) { - if (node.GetType() != OpType::SortGroupBy) return false; + if (node.GetOpType() != OpType::SortGroupBy) return false; const PhysicalSortGroupBy &r = *static_cast(&node); if (having.size() != r.having.size() || columns.size() != r.columns.size()) @@ -1099,113 +1099,118 @@ std::string OperatorNode::name_ = //===--------------------------------------------------------------------===// template <> -OpType OperatorNode::type_ = OpType::Leaf; +OpType OperatorNode::op_type_ = OpType::Leaf; template <> -OpType OperatorNode::type_ = OpType::Get; +OpType OperatorNode::op_type_ = OpType::Get; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::LogicalExternalFileGet; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = 
OpType::LogicalQueryDerivedGet; template <> -OpType OperatorNode::type_ = OpType::LogicalFilter; +OpType OperatorNode::op_type_ = OpType::LogicalFilter; template <> -OpType OperatorNode::type_ = OpType::LogicalProjection; +OpType OperatorNode::op_type_ = OpType::LogicalProjection; template <> -OpType OperatorNode::type_ = OpType::LogicalMarkJoin; +OpType OperatorNode::op_type_ = OpType::LogicalMarkJoin; template <> -OpType OperatorNode::type_ = OpType::LogicalSingleJoin; +OpType OperatorNode::op_type_ = OpType::LogicalSingleJoin; template <> -OpType OperatorNode::type_ = OpType::LogicalDependentJoin; +OpType OperatorNode::op_type_ = OpType::LogicalDependentJoin; template <> -OpType OperatorNode::type_ = OpType::InnerJoin; +OpType OperatorNode::op_type_ = OpType::InnerJoin; template <> -OpType OperatorNode::type_ = OpType::LeftJoin; +OpType OperatorNode::op_type_ = OpType::LeftJoin; template <> -OpType OperatorNode::type_ = OpType::RightJoin; +OpType OperatorNode::op_type_ = OpType::RightJoin; template <> -OpType OperatorNode::type_ = OpType::OuterJoin; +OpType OperatorNode::op_type_ = OpType::OuterJoin; template <> -OpType OperatorNode::type_ = OpType::SemiJoin; +OpType OperatorNode::op_type_ = OpType::SemiJoin; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::LogicalAggregateAndGroupBy; template <> -OpType OperatorNode::type_ = OpType::LogicalInsert; +OpType OperatorNode::op_type_ = OpType::LogicalInsert; template <> -OpType OperatorNode::type_ = OpType::LogicalInsertSelect; +OpType OperatorNode::op_type_ = OpType::LogicalInsertSelect; template <> -OpType OperatorNode::type_ = OpType::LogicalUpdate; +OpType OperatorNode::op_type_ = OpType::LogicalUpdate; template <> -OpType OperatorNode::type_ = OpType::LogicalDelete; +OpType OperatorNode::op_type_ = OpType::LogicalDelete; template <> -OpType OperatorNode::type_ = OpType::LogicalDistinct; +OpType OperatorNode::op_type_ = OpType::LogicalDistinct; template <> -OpType 
OperatorNode::type_ = OpType::LogicalLimit; +OpType OperatorNode::op_type_ = OpType::LogicalLimit; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::LogicalExportExternalFile; template <> -OpType OperatorNode::type_ = OpType::DummyScan; +OpType OperatorNode::op_type_ = OpType::DummyScan; template <> -OpType OperatorNode::type_ = OpType::SeqScan; +OpType OperatorNode::op_type_ = OpType::SeqScan; template <> -OpType OperatorNode::type_ = OpType::IndexScan; +OpType OperatorNode::op_type_ = OpType::IndexScan; template <> -OpType OperatorNode::type_ = OpType::ExternalFileScan; +OpType OperatorNode::op_type_ = OpType::ExternalFileScan; template <> -OpType OperatorNode::type_ = OpType::QueryDerivedScan; +OpType OperatorNode::op_type_ = OpType::QueryDerivedScan; template <> -OpType OperatorNode::type_ = OpType::OrderBy; +OpType OperatorNode::op_type_ = OpType::OrderBy; template <> -OpType OperatorNode::type_ = OpType::Distinct; +OpType OperatorNode::op_type_ = OpType::Distinct; template <> -OpType OperatorNode::type_ = OpType::PhysicalLimit; +OpType OperatorNode::op_type_ = OpType::PhysicalLimit; template <> -OpType OperatorNode::type_ = OpType::InnerNLJoin; +OpType OperatorNode::op_type_ = OpType::InnerNLJoin; template <> -OpType OperatorNode::type_ = OpType::LeftNLJoin; +OpType OperatorNode::op_type_ = OpType::LeftNLJoin; template <> -OpType OperatorNode::type_ = OpType::RightNLJoin; +OpType OperatorNode::op_type_ = OpType::RightNLJoin; template <> -OpType OperatorNode::type_ = OpType::OuterNLJoin; +OpType OperatorNode::op_type_ = OpType::OuterNLJoin; template <> -OpType OperatorNode::type_ = OpType::InnerHashJoin; +OpType OperatorNode::op_type_ = OpType::InnerHashJoin; template <> -OpType OperatorNode::type_ = OpType::LeftHashJoin; +OpType OperatorNode::op_type_ = OpType::LeftHashJoin; template <> -OpType OperatorNode::type_ = OpType::RightHashJoin; +OpType OperatorNode::op_type_ = OpType::RightHashJoin; template <> -OpType 
OperatorNode::type_ = OpType::OuterHashJoin; +OpType OperatorNode::op_type_ = OpType::OuterHashJoin; template <> -OpType OperatorNode::type_ = OpType::Insert; +OpType OperatorNode::op_type_ = OpType::Insert; template <> -OpType OperatorNode::type_ = OpType::InsertSelect; +OpType OperatorNode::op_type_ = OpType::InsertSelect; template <> -OpType OperatorNode::type_ = OpType::Delete; +OpType OperatorNode::op_type_ = OpType::Delete; template <> -OpType OperatorNode::type_ = OpType::Update; +OpType OperatorNode::op_type_ = OpType::Update; template <> -OpType OperatorNode::type_ = OpType::HashGroupBy; +OpType OperatorNode::op_type_ = OpType::HashGroupBy; template <> -OpType OperatorNode::type_ = OpType::SortGroupBy; +OpType OperatorNode::op_type_ = OpType::SortGroupBy; template <> -OpType OperatorNode::type_ = OpType::Aggregate; +OpType OperatorNode::op_type_ = OpType::Aggregate; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::ExportExternalFile; + +//===--------------------------------------------------------------------===// +template +ExpressionType OperatorNode::exp_type_ = ExpressionType::INVALID; + //===--------------------------------------------------------------------===// template bool OperatorNode::IsLogical() const { - return type_ < OpType::LogicalPhysicalDelimiter; + return op_type_ < OpType::LogicalPhysicalDelimiter; } template bool OperatorNode::IsPhysical() const { - return type_ > OpType::LogicalPhysicalDelimiter; + return op_type_ > OpType::LogicalPhysicalDelimiter; } template <> diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index 83bcadde4de..c12205c3c35 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -327,7 +327,7 @@ const std::string Optimizer::GetOperatorInfo( auto gexpr = group->GetBestExpression(required_props); os << std::endl << StringUtil::Indent(num_indent) << "operator name: " - << gexpr->Op().GetName().c_str(); + << gexpr->Op()->GetName().c_str(); 
vector child_groups = gexpr->GetChildGroupIDs(); auto required_input_props = gexpr->GetInputProperties(required_props); @@ -352,7 +352,7 @@ unique_ptr Optimizer::ChooseBestPlan( auto gexpr = group->GetBestExpression(required_props); LOG_TRACE("Choosing best plan for group %d with op %s", gexpr->GetGroupID(), - gexpr->Op().GetName().c_str()); + gexpr->Op()->GetName().c_str()); vector child_groups = gexpr->GetChildGroupIDs(); auto required_input_props = gexpr->GetInputProperties(required_props); @@ -383,8 +383,9 @@ unique_ptr Optimizer::ChooseBestPlan( } // Derive root plan + // TODO(ncx): fix once OperatorExpression is abstracted shared_ptr op = - make_shared(gexpr->Op()); + make_shared(*(Operator *)gexpr->Op().get()); PlanGenerator generator; auto plan = generator.ConvertOpExpression(op, required_props, required_cols, diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index e1cfac5643d..cd3f7e5c778 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -31,8 +31,12 @@ void OptimizerTask::ConstructValidRules( std::vector &valid_rules) { for (auto &rule : rules) { // Check if we can apply the rule + // TODO(ncx): replace after pattern fix + // bool root_pattern_mismatch = + // group_expr->Op()->GetOpType() != rule->GetMatchPattern()->OpType() + // || group_expr->Op()->GetExpType() != rule->GetMatchPattern()->ExpType(); bool root_pattern_mismatch = - group_expr->Op().GetType() != rule->GetMatchPattern()->Type(); + group_expr->Op()->GetOpType() != rule->GetMatchPattern()->Type(); bool already_explored = group_expr->HasRuleExplored(rule.get()); bool child_pattern_mismatch = group_expr->GetChildrenGroupsSize() != @@ -97,7 +101,7 @@ void OptimizeExpression::execute() { std::sort(valid_rules.begin(), valid_rules.end()); LOG_DEBUG("OptimizeExpression::execute() op %d, valid rules : %lu", - static_cast(group_expr_->Op().GetType()), valid_rules.size()); + static_cast(group_expr_->Op()->GetOpType()), 
valid_rules.size()); // Apply rule for (auto &r : valid_rules) { PushTask(new ApplyRule(group_expr_, r.rule, context_)); @@ -187,7 +191,7 @@ void ApplyRule::execute() { if (context_->metadata->RecordTransformedExpression( new_expr, new_gexpr, group_expr_->GetGroupID())) { // A new group expression is generated - if (new_gexpr->Op().IsLogical()) { + if (new_gexpr->Op()->IsLogical()) { // Derive stats for the *logical expression* PushTask(new DeriveStats(new_gexpr.get(), ExprSet{}, context_)); if (explore_only) { diff --git a/src/optimizer/property_enforcer.cpp b/src/optimizer/property_enforcer.cpp index 834cf9a76d7..4069eb5716a 100644 --- a/src/optimizer/property_enforcer.cpp +++ b/src/optimizer/property_enforcer.cpp @@ -33,13 +33,15 @@ void PropertyEnforcer::Visit(const PropertyColumns *) { void PropertyEnforcer::Visit(const PropertySort *) { std::vector child_groups(1, input_gexpr_->GetGroupID()); output_gexpr_ = - std::make_shared(PhysicalOrderBy::make(), child_groups); + std::make_shared( + std::make_shared(PhysicalOrderBy::make()), child_groups); } void PropertyEnforcer::Visit(const PropertyDistinct *) { std::vector child_groups(1, input_gexpr_->GetGroupID()); output_gexpr_ = - std::make_shared(PhysicalDistinct::make(), child_groups); + std::make_shared( + std::make_shared(PhysicalOrderBy::make()), child_groups); } void PropertyEnforcer::Visit(const PropertyLimit *) {} diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 8c72ed17fa8..4357d81ae2f 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -18,9 +18,11 @@ namespace optimizer { int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; + // TODO(ncx): replace after pattern fix + // auto root_type = match_pattern->OpType(); auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != 
group_expr->Op()->GetOpType()) { return 0; } if (IsPhysical()) return PHYS_PROMISE; diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 8574e00f337..3138492f7d4 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -102,7 +102,7 @@ void InnerJoinAssociativity::Transform( auto parent_join = input->Op().As(); std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 2); - PELOTON_ASSERT(children[0]->Op().GetType() == OpType::InnerJoin); + PELOTON_ASSERT(children[0]->Op().GetOpType() == OpType::InnerJoin); PELOTON_ASSERT(children[0]->Children().size() == 2); auto child_join = children[0]->Op().As(); auto left = children[0]->Children()[0]; @@ -1115,7 +1115,7 @@ int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { return 0; } return static_cast(UnnestPromise::Low); @@ -1166,7 +1166,7 @@ int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { return 0; } return static_cast(UnnestPromise::Low); @@ -1219,7 +1219,7 @@ int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { return 0; } return static_cast(UnnestPromise::High); @@ -1280,7 +1280,7 @@ int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, (void)context; auto root_type = 
match_pattern->Type(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { return 0; } return static_cast(UnnestPromise::High); diff --git a/src/optimizer/stats/child_stats_deriver.cpp b/src/optimizer/stats/child_stats_deriver.cpp index d320547915c..5831dfdaffb 100644 --- a/src/optimizer/stats/child_stats_deriver.cpp +++ b/src/optimizer/stats/child_stats_deriver.cpp @@ -27,7 +27,7 @@ vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, gexpr_ = gexpr; memo_ = memo; output_ = vector(gexpr->GetChildrenGroupsSize(), ExprSet{}); - gexpr->Op().Accept(this); + gexpr->Op()->Accept(this); return std::move(output_); } diff --git a/src/optimizer/stats/stats_calculator.cpp b/src/optimizer/stats/stats_calculator.cpp index d086938a817..b327aa2df8b 100644 --- a/src/optimizer/stats/stats_calculator.cpp +++ b/src/optimizer/stats/stats_calculator.cpp @@ -33,7 +33,7 @@ void StatsCalculator::CalculateStats(GroupExpression *gexpr, memo_ = memo; required_cols_ = required_cols; txn_ = txn; - gexpr->Op().Accept(this); + gexpr->Op()->Accept(this); } void StatsCalculator::Visit(const LogicalGet *op) { diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index f1ffd6add66..242d62896dc 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -348,7 +348,8 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { std::vector child_groups = {gexpr->GetGroupID()}; std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); + std::make_shared( + std::make_shared(Operator()), child_groups); std::shared_ptr root_context = std::make_shared(&(optimizer.GetMetadata()), nullptr); @@ -367,28 +368,28 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, 
group_expr->Op().GetType()); - auto join_op = group_expr->Op().As(); + EXPECT_EQ(OpType::InnerJoin, group_expr->Op()->GetOpType()); + auto join_op = group_expr->Op()->As(); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); // Check left get auto l_group_expr = GetSingleGroupExpression(memo, group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op().GetType()); - auto get_op = l_group_expr->Op().As(); + EXPECT_EQ(OpType::Get, l_group_expr->Op()->GetOpType()); + auto get_op = l_group_expr->Op()->As(); EXPECT_TRUE(get_op->predicates.empty()); // Check right filter auto r_group_expr = GetSingleGroupExpression(memo, group_expr, 1); - EXPECT_EQ(OpType::LogicalFilter, r_group_expr->Op().GetType()); - auto filter_op = r_group_expr->Op().As(); + EXPECT_EQ(OpType::LogicalFilter, r_group_expr->Op()->GetOpType()); + auto filter_op = r_group_expr->Op()->As(); EXPECT_EQ(1, filter_op->predicates.size()); EXPECT_TRUE(filter_op->predicates[0].expr->ExactlyEquals(*predicates[1])); // Check get below filter group_expr = GetSingleGroupExpression(memo, r_group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op().GetType()); - get_op = group_expr->Op().As(); + EXPECT_EQ(OpType::Get, l_group_expr->Op()->GetOpType()); + get_op = group_expr->Op()->As(); EXPECT_TRUE(get_op->predicates.empty()); txn_manager.CommitTransaction(txn); @@ -435,7 +436,8 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { std::vector child_groups = {gexpr->GetGroupID()}; std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); + std::make_shared( + std::make_shared(Operator()), child_groups); std::shared_ptr root_context = std::make_shared(&(optimizer.GetMetadata()), nullptr); @@ -454,21 +456,21 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); - auto 
join_op = group_expr->Op().As(); + EXPECT_EQ(OpType::InnerJoin, group_expr->Op()->GetOpType()); + auto join_op = group_expr->Op()->As(); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); // Check left get auto l_group_expr = GetSingleGroupExpression(memo, group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op().GetType()); - auto get_op = l_group_expr->Op().As(); + EXPECT_EQ(OpType::Get, l_group_expr->Op()->GetOpType()); + auto get_op = l_group_expr->Op()->As(); EXPECT_TRUE(get_op->predicates.empty()); // Check right filter auto r_group_expr = GetSingleGroupExpression(memo, group_expr, 1); - EXPECT_EQ(OpType::Get, r_group_expr->Op().GetType()); - get_op = r_group_expr->Op().As(); + EXPECT_EQ(OpType::Get, r_group_expr->Op()->GetOpType()); + get_op = r_group_expr->Op()->As(); EXPECT_EQ(1, get_op->predicates.size()); EXPECT_TRUE(get_op->predicates[0].expr->ExactlyEquals(*predicates[1])); From 3266f2955d70740b1388ca6b50782a96f91dce1d Mon Sep 17 00:00:00 2001 From: Newton Xie Date: Sat, 27 Apr 2019 02:34:38 -0400 Subject: [PATCH 06/14] Adding AbstractNodeExpression interface. 
--- src/include/optimizer/abstract_node.h | 2 +- .../optimizer/abstract_node_expression.h | 44 +++ src/include/optimizer/binding.h | 13 +- .../optimizer/cost_model/default_cost_model.h | 2 +- .../cost_model/postgres_cost_model.h | 2 +- .../optimizer/cost_model/trivial_cost_model.h | 2 +- src/include/optimizer/group_expression.h | 2 +- src/include/optimizer/operator_expression.h | 23 +- src/include/optimizer/operator_node.h | 2 +- src/include/optimizer/operators.h | 106 +++--- src/include/optimizer/optimizer_metadata.h | 11 +- src/include/optimizer/rule.h | 6 +- src/include/optimizer/rule_impls.h | 156 ++++----- src/optimizer/binding.cpp | 20 +- src/optimizer/child_property_deriver.cpp | 2 +- src/optimizer/group.cpp | 10 +- src/optimizer/group_expression.cpp | 4 +- src/optimizer/input_column_deriver.cpp | 2 +- src/optimizer/memo.cpp | 10 +- src/optimizer/operator_expression.cpp | 18 +- src/optimizer/operator_node.cpp | 2 +- src/optimizer/operators.cpp | 218 ++++++------- src/optimizer/optimizer.cpp | 6 +- src/optimizer/optimizer_task.cpp | 16 +- src/optimizer/plan_generator.cpp | 2 +- src/optimizer/rule.cpp | 2 +- src/optimizer/rule_impls.cpp | 305 +++++++++--------- src/optimizer/stats/child_stats_deriver.cpp | 2 +- src/optimizer/stats/stats_calculator.cpp | 2 +- test/optimizer/optimizer_rule_test.cpp | 26 +- test/optimizer/optimizer_test.cpp | 28 +- 31 files changed, 543 insertions(+), 503 deletions(-) create mode 100644 src/include/optimizer/abstract_node_expression.h diff --git a/src/include/optimizer/abstract_node.h b/src/include/optimizer/abstract_node.h index 4a076751abc..2b5b5f40f4b 100644 --- a/src/include/optimizer/abstract_node.h +++ b/src/include/optimizer/abstract_node.h @@ -83,7 +83,7 @@ enum class OpType { class OperatorVisitor; struct AbstractNode { - AbstractNode(AbstractNode *node) : node(node) {} + AbstractNode(std::shared_ptr node) : node(node) {} ~AbstractNode() {} diff --git a/src/include/optimizer/abstract_node_expression.h 
b/src/include/optimizer/abstract_node_expression.h new file mode 100644 index 00000000000..01dbc40683c --- /dev/null +++ b/src/include/optimizer/abstract_node_expression.h @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// abstract_node_expression.h +// +// Identification: src/include/optimizer/abstract_node_expression.h +// +// Copyright (c) 2015-19, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "optimizer/abstract_node.h" + +#include +#include + +namespace peloton { +namespace optimizer { + +//===--------------------------------------------------------------------===// +// Abstract Node Expression +//===--------------------------------------------------------------------===// +class AbstractNodeExpression { + public: + AbstractNodeExpression() {} + + ~AbstractNodeExpression() {} + + virtual void PushChild(std::shared_ptr child) = 0; + + virtual void PopChild() = 0; + + virtual const std::vector> &Children() const = 0; + + virtual const std::shared_ptr Node() const = 0; + + virtual const std::string GetInfo() const = 0; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/binding.h b/src/include/optimizer/binding.h index 7a6d772813d..1bb16eddb6b 100644 --- a/src/include/optimizer/binding.h +++ b/src/include/optimizer/binding.h @@ -12,13 +12,14 @@ #pragma once +#include "operator_expression.h" #include "optimizer/operator_node.h" #include "optimizer/group.h" #include "optimizer/pattern.h" + #include #include #include -#include "operator_expression.h" namespace peloton { namespace optimizer { @@ -37,7 +38,7 @@ class BindingIterator { virtual bool HasNext() = 0; - virtual std::shared_ptr Next() = 0; + virtual std::shared_ptr Next() = 0; protected: Memo &memo_; @@ -50,7 +51,7 @@ class GroupBindingIterator : public BindingIterator { bool HasNext() 
override; - std::shared_ptr Next() override; + std::shared_ptr Next() override; private: GroupID group_id_; @@ -70,7 +71,7 @@ class GroupExprBindingIterator : public BindingIterator { bool HasNext() override; - std::shared_ptr Next() override; + std::shared_ptr Next() override; private: GroupExpression* gexpr_; @@ -78,8 +79,8 @@ class GroupExprBindingIterator : public BindingIterator { bool first_; bool has_next_; - std::shared_ptr current_binding_; - std::vector>> + std::shared_ptr current_binding_; + std::vector>> children_bindings_; std::vector children_bindings_pos_; }; diff --git a/src/include/optimizer/cost_model/default_cost_model.h b/src/include/optimizer/cost_model/default_cost_model.h index d5b9be3c82b..c8decace1f0 100644 --- a/src/include/optimizer/cost_model/default_cost_model.h +++ b/src/include/optimizer/cost_model/default_cost_model.h @@ -34,7 +34,7 @@ class DefaultCostModel : public AbstractCostModel { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - gexpr_->Op()->Accept(this); + gexpr_->Node()->Accept(this); return output_cost_; } diff --git a/src/include/optimizer/cost_model/postgres_cost_model.h b/src/include/optimizer/cost_model/postgres_cost_model.h index 87e9858cd06..58224791c8e 100644 --- a/src/include/optimizer/cost_model/postgres_cost_model.h +++ b/src/include/optimizer/cost_model/postgres_cost_model.h @@ -39,7 +39,7 @@ class PostgresCostModel : public AbstractCostModel { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - gexpr_->Op()->Accept(this); + gexpr_->Node()->Accept(this); return output_cost_; }; diff --git a/src/include/optimizer/cost_model/trivial_cost_model.h b/src/include/optimizer/cost_model/trivial_cost_model.h index 5e3d7fd195f..9c8a1ddb58c 100644 --- a/src/include/optimizer/cost_model/trivial_cost_model.h +++ b/src/include/optimizer/cost_model/trivial_cost_model.h @@ -41,7 +41,7 @@ class TrivialCostModel : public AbstractCostModel { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - gexpr_->Op()->Accept(this); + gexpr_->Node()->Accept(this); 
return output_cost_; }; diff --git a/src/include/optimizer/group_expression.h b/src/include/optimizer/group_expression.h index 6a25871cbb1..8bf2748336a 100644 --- a/src/include/optimizer/group_expression.h +++ b/src/include/optimizer/group_expression.h @@ -46,7 +46,7 @@ class GroupExpression { GroupID GetChildGroupId(int child_idx) const; - std::shared_ptr Op() const; + std::shared_ptr Node() const; double GetCost(std::shared_ptr& requirements) const; diff --git a/src/include/optimizer/operator_expression.h b/src/include/optimizer/operator_expression.h index a37020915f4..2bdd4018b67 100644 --- a/src/include/optimizer/operator_expression.h +++ b/src/include/optimizer/operator_expression.h @@ -2,16 +2,17 @@ // // Peloton // -// op_expression.h +// operator_expression.h // -// Identification: src/include/optimizer/op_expression.h +// Identification: src/include/optimizer/operator_expression.h // -// Copyright (c) 2015-16, Carnegie Mellon University Database Group +// Copyright (c) 2015-19, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// #pragma once +#include "optimizer/abstract_node_expression.h" #include "optimizer/operator_node.h" #include @@ -21,25 +22,25 @@ namespace peloton { namespace optimizer { //===--------------------------------------------------------------------===// -// Operator Expr +// Operator Expression //===--------------------------------------------------------------------===// -class OperatorExpression { +class OperatorExpression : public AbstractNodeExpression { public: - OperatorExpression(Operator op); + OperatorExpression(std::shared_ptr node); - void PushChild(std::shared_ptr op); + void PushChild(std::shared_ptr child); void PopChild(); - const std::vector> &Children() const; + const std::vector> &Children() const; - const Operator &Op() const; + const std::shared_ptr Node() const; const std::string GetInfo() const; private: - Operator op; - std::vector> 
children; + std::shared_ptr node; + std::vector> children; }; } // namespace optimizer diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index 3a26c6daaa4..ec896978a75 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -57,7 +57,7 @@ class Operator : public AbstractNode { public: Operator(); - Operator(AbstractNode *node); + Operator(std::shared_ptr node); void Accept(OperatorVisitor *v) const; diff --git a/src/include/optimizer/operators.h b/src/include/optimizer/operators.h index 80ba32e94ef..5b3b9f04847 100644 --- a/src/include/optimizer/operators.h +++ b/src/include/optimizer/operators.h @@ -38,9 +38,9 @@ class PropertySort; //===--------------------------------------------------------------------===// // Leaf //===--------------------------------------------------------------------===// -class LeafOperator : OperatorNode { +class LeafOperator : public OperatorNode { public: - static Operator make(GroupID group); + static std::shared_ptr make(GroupID group); GroupID origin_group; }; @@ -50,7 +50,7 @@ class LeafOperator : OperatorNode { //===--------------------------------------------------------------------===// class LogicalGet : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( oid_t get_id = 0, std::vector predicates = {}, std::shared_ptr table = nullptr, std::string alias = "", bool update = false); @@ -72,7 +72,7 @@ class LogicalGet : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalExternalFileGet : public OperatorNode { public: - static Operator make(oid_t get_id, ExternalFileFormat format, + static std::shared_ptr make(oid_t get_id, ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); @@ -94,7 +94,7 @@ class LogicalExternalFileGet : public OperatorNode { 
//===--------------------------------------------------------------------===// class LogicalQueryDerivedGet : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( oid_t get_id, std::string &alias, std::unordered_map> @@ -117,7 +117,7 @@ class LogicalQueryDerivedGet : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalFilter : public OperatorNode { public: - static Operator make(std::vector &filter); + static std::shared_ptr make(std::vector &filter); std::vector predicates; bool operator==(const AbstractNode &r) override; @@ -130,7 +130,7 @@ class LogicalFilter : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalProjection : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector> &elements); std::vector> expressions; }; @@ -140,9 +140,9 @@ class LogicalProjection : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalDependentJoin : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make(std::vector &conditions); + static std::shared_ptr make(std::vector &conditions); bool operator==(const AbstractNode &r) override; @@ -156,9 +156,9 @@ class LogicalDependentJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalMarkJoin : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make(std::vector &conditions); + static std::shared_ptr make(std::vector &conditions); bool operator==(const AbstractNode &r) override; @@ -172,9 +172,9 @@ class LogicalMarkJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalSingleJoin : public OperatorNode { public: - static Operator make(); + 
static std::shared_ptr make(); - static Operator make(std::vector &conditions); + static std::shared_ptr make(std::vector &conditions); bool operator==(const AbstractNode &r) override; @@ -188,9 +188,9 @@ class LogicalSingleJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalInnerJoin : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make(std::vector &conditions); + static std::shared_ptr make(std::vector &conditions); bool operator==(const AbstractNode &r) override; @@ -204,7 +204,7 @@ class LogicalInnerJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalLeftJoin : public OperatorNode { public: - static Operator make(expression::AbstractExpression *condition = nullptr); + static std::shared_ptr make(expression::AbstractExpression *condition = nullptr); std::shared_ptr join_predicate; }; @@ -214,7 +214,7 @@ class LogicalLeftJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalRightJoin : public OperatorNode { public: - static Operator make(expression::AbstractExpression *condition = nullptr); + static std::shared_ptr make(expression::AbstractExpression *condition = nullptr); std::shared_ptr join_predicate; }; @@ -224,7 +224,7 @@ class LogicalRightJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalOuterJoin : public OperatorNode { public: - static Operator make(expression::AbstractExpression *condition = nullptr); + static std::shared_ptr make(expression::AbstractExpression *condition = nullptr); std::shared_ptr join_predicate; }; @@ -234,7 +234,7 @@ class LogicalOuterJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalSemiJoin : public OperatorNode { public: - static 
Operator make(expression::AbstractExpression *condition = nullptr); + static std::shared_ptr make(expression::AbstractExpression *condition = nullptr); std::shared_ptr join_predicate; }; @@ -245,12 +245,12 @@ class LogicalSemiJoin : public OperatorNode { class LogicalAggregateAndGroupBy : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make( + static std::shared_ptr make( std::vector> &columns); - static Operator make( + static std::shared_ptr make( std::vector> &columns, std::vector &having); @@ -266,7 +266,7 @@ class LogicalAggregateAndGroupBy //===--------------------------------------------------------------------===// class LogicalInsert : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table, const std::vector *columns, const std::vector { class LogicalInsertSelect : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table); std::shared_ptr target_table; @@ -291,7 +291,7 @@ class LogicalInsertSelect : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalDistinct : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; //===--------------------------------------------------------------------===// @@ -299,7 +299,7 @@ class LogicalDistinct : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalLimit : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( int64_t offset, int64_t limit, std::vector &&sort_exprs, std::vector &&sort_ascending); @@ -318,7 +318,7 @@ class LogicalLimit : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalDelete : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr 
target_table); std::shared_ptr target_table; @@ -329,7 +329,7 @@ class LogicalDelete : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalUpdate : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table, const std::vector> *updates); @@ -343,7 +343,7 @@ class LogicalUpdate : public OperatorNode { class LogicalExportExternalFile : public OperatorNode { public: - static Operator make(ExternalFileFormat format, std::string file_name, + static std::shared_ptr make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); bool operator==(const AbstractNode &r) override; @@ -362,7 +362,7 @@ class LogicalExportExternalFile //===--------------------------------------------------------------------===// class DummyScan : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; //===--------------------------------------------------------------------===// @@ -370,7 +370,7 @@ class DummyScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalSeqScan : public OperatorNode { public: - static Operator make(oid_t get_id, + static std::shared_ptr make(oid_t get_id, std::shared_ptr table, std::string alias, std::vector predicates, @@ -393,7 +393,7 @@ class PhysicalSeqScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalIndexScan : public OperatorNode { public: - static Operator make(oid_t get_id, + static std::shared_ptr make(oid_t get_id, std::shared_ptr table, std::string alias, std::vector predicates, bool update, @@ -426,7 +426,7 @@ class PhysicalIndexScan : public OperatorNode { //===--------------------------------------------------------------------===// class ExternalFileScan : public OperatorNode { public: - static Operator make(oid_t get_id, 
ExternalFileFormat format, + static std::shared_ptr make(oid_t get_id, ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); @@ -448,7 +448,7 @@ class ExternalFileScan : public OperatorNode { //===--------------------------------------------------------------------===// class QueryDerivedScan : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( oid_t get_id, std::string alias, std::unordered_map> @@ -471,7 +471,7 @@ class QueryDerivedScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalOrderBy : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; //===--------------------------------------------------------------------===// @@ -479,7 +479,7 @@ class PhysicalOrderBy : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalLimit : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( int64_t offset, int64_t limit, std::vector sort_columns, std::vector sort_ascending); @@ -498,7 +498,7 @@ class PhysicalLimit : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalInnerNLJoin : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); @@ -519,7 +519,7 @@ class PhysicalInnerNLJoin : public OperatorNode { class PhysicalLeftNLJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -529,7 +529,7 @@ class PhysicalLeftNLJoin : public OperatorNode { class PhysicalRightNLJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -539,7 +539,7 
@@ class PhysicalRightNLJoin : public OperatorNode { class PhysicalOuterNLJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -548,7 +548,7 @@ class PhysicalOuterNLJoin : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalInnerHashJoin : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); @@ -569,7 +569,7 @@ class PhysicalInnerHashJoin : public OperatorNode { class PhysicalLeftHashJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -579,7 +579,7 @@ class PhysicalLeftHashJoin : public OperatorNode { class PhysicalRightHashJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -589,7 +589,7 @@ class PhysicalRightHashJoin : public OperatorNode { class PhysicalOuterHashJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -598,7 +598,7 @@ class PhysicalOuterHashJoin : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalInsert : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table, const std::vector *columns, const std::vector { class PhysicalInsertSelect : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table); std::shared_ptr target_table; @@ -623,7 +623,7 @@ class PhysicalInsertSelect : public OperatorNode { //===--------------------------------------------------------------------===// 
class PhysicalDelete : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table); std::shared_ptr target_table; }; @@ -633,7 +633,7 @@ class PhysicalDelete : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalUpdate : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table, const std::vector> *updates); @@ -647,7 +647,7 @@ class PhysicalUpdate : public OperatorNode { class PhysicalExportExternalFile : public OperatorNode { public: - static Operator make(ExternalFileFormat format, std::string file_name, + static std::shared_ptr make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); bool operator==(const AbstractNode &r) override; @@ -666,7 +666,7 @@ class PhysicalExportExternalFile //===--------------------------------------------------------------------===// class PhysicalHashGroupBy : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector> columns, std::vector having); @@ -682,7 +682,7 @@ class PhysicalHashGroupBy : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalSortGroupBy : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector> columns, std::vector having); @@ -698,12 +698,12 @@ class PhysicalSortGroupBy : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalAggregate : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; class PhysicalDistinct : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; } // namespace optimizer diff --git a/src/include/optimizer/optimizer_metadata.h b/src/include/optimizer/optimizer_metadata.h index 85782dd09bf..57dcb2ec7d8 
100644 --- a/src/include/optimizer/optimizer_metadata.h +++ b/src/include/optimizer/optimizer_metadata.h @@ -53,23 +53,22 @@ class OptimizerMetadata { } std::shared_ptr MakeGroupExpression( - std::shared_ptr expr) { + std::shared_ptr expr) { std::vector child_groups; for (auto &child : expr->Children()) { auto gexpr = MakeGroupExpression(child); memo.InsertExpression(gexpr, false); child_groups.push_back(gexpr->GetGroupID()); } - return std::make_shared(std::make_shared(expr->Op()), - std::move(child_groups)); + return std::make_shared(expr->Node(), std::move(child_groups)); } - bool RecordTransformedExpression(std::shared_ptr expr, + bool RecordTransformedExpression(std::shared_ptr expr, std::shared_ptr &gexpr) { return RecordTransformedExpression(expr, gexpr, UNDEFINED_GROUP); } - bool RecordTransformedExpression(std::shared_ptr expr, + bool RecordTransformedExpression(std::shared_ptr expr, std::shared_ptr &gexpr, GroupID target_group) { gexpr = MakeGroupExpression(expr); @@ -77,7 +76,7 @@ class OptimizerMetadata { } // TODO(boweic): check if we really need to use shared_ptr - void ReplaceRewritedExpression(std::shared_ptr expr, + void ReplaceRewritedExpression(std::shared_ptr expr, GroupID target_group) { memo.EraseExpression(target_group); memo.InsertExpression(MakeGroupExpression(expr), target_group, false); diff --git a/src/include/optimizer/rule.h b/src/include/optimizer/rule.h index 4ea78a630c6..d561bdf2f62 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -74,7 +74,7 @@ class Rule { * * @return If the rule is applicable, return true, otherwise return false */ - virtual bool Check(std::shared_ptr expr, + virtual bool Check(std::shared_ptr expr, OptimizeContext *context) const = 0; /** @@ -85,8 +85,8 @@ class Rule { * @param context The current optimization context */ virtual void Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const 
= 0; inline RuleType GetType() { return type_; } diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index 57902e744a9..9296f2ce3a9 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -30,11 +30,11 @@ class InnerJoinCommutativity : public Rule { public: InnerJoinCommutativity(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -46,11 +46,11 @@ class InnerJoinAssociativity : public Rule { public: InnerJoinAssociativity(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -65,11 +65,11 @@ class GetToSeqScan : public Rule { public: GetToSeqScan(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -77,11 +77,11 @@ class LogicalExternalFileGetToPhysical : public Rule { public: LogicalExternalFileGetToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -93,11 +93,11 @@ class GetToDummyScan : public Rule { public: GetToDummyScan(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext 
*context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -108,11 +108,11 @@ class GetToIndexScan : public Rule { public: GetToIndexScan(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -123,11 +123,11 @@ class LogicalQueryDerivedGetToPhysical : public Rule { public: LogicalQueryDerivedGetToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -138,11 +138,11 @@ class LogicalDeleteToPhysical : public Rule { public: LogicalDeleteToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -153,11 +153,11 @@ class LogicalUpdateToPhysical : public Rule { public: LogicalUpdateToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -168,11 +168,11 @@ class LogicalInsertToPhysical : public Rule { public: LogicalInsertToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) 
const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -183,11 +183,11 @@ class LogicalInsertSelectToPhysical : public Rule { public: LogicalInsertSelectToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -198,11 +198,11 @@ class LogicalGroupByToHashGroupBy : public Rule { public: LogicalGroupByToHashGroupBy(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -213,11 +213,11 @@ class LogicalAggregateToPhysical : public Rule { public: LogicalAggregateToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -228,11 +228,11 @@ class InnerJoinToInnerNLJoin : public Rule { public: InnerJoinToInnerNLJoin(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -243,11 +243,11 @@ class InnerJoinToInnerHashJoin : public Rule { public: InnerJoinToInnerHashJoin(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, 
OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -258,11 +258,11 @@ class ImplementDistinct : public Rule { public: ImplementDistinct(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -273,11 +273,11 @@ class ImplementLimit : public Rule { public: ImplementLimit(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -288,11 +288,11 @@ class LogicalExportToPhysicalExport : public Rule { public: LogicalExportToPhysicalExport(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -310,11 +310,11 @@ class PushFilterThroughJoin : public Rule { public: PushFilterThroughJoin(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -325,11 +325,11 @@ class CombineConsecutiveFilter : public Rule { public: CombineConsecutiveFilter(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) 
const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -341,11 +341,11 @@ class PushFilterThroughAggregation : public Rule { public: PushFilterThroughAggregation(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; /** @@ -357,11 +357,11 @@ class EmbedFilterIntoGet : public Rule { public: EmbedFilterIntoGet(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -384,11 +384,11 @@ class MarkJoinToInnerJoin : public Rule { int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// @@ -400,11 +400,11 @@ class SingleJoinToInnerJoin : public Rule { int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -417,11 +417,11 @@ class 
PullFilterThroughMarkJoin : public Rule { int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -434,11 +434,11 @@ class PullFilterThroughAggregation : public Rule { int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; } // namespace optimizer diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index ab6d7a03bd8..8333f8b5055 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -67,7 +67,7 @@ bool GroupBindingIterator::HasNext() { return current_iterator_ != nullptr; } -std::shared_ptr GroupBindingIterator::Next() { +std::shared_ptr GroupBindingIterator::Next() { if (pattern_->Type() == OpType::Leaf) { current_item_index_ = num_group_items_; return std::make_shared(LeafOperator::make(group_id_)); @@ -85,9 +85,9 @@ GroupExprBindingIterator::GroupExprBindingIterator( pattern_(pattern), first_(true), has_next_(false), - // TODO(ncx): fix once OperatorExpression is abstracted - current_binding_(std::make_shared(*(Operator *)gexpr->Op().get())) { - if (gexpr->Op()->GetOpType() != pattern->Type()) { + // TODO(ncx): needs workaround when Node is not an Operator + current_binding_(std::make_shared(gexpr->Node())) { + if (gexpr->Node()->GetOpType() != pattern->Type()) { return; } @@ -101,14 +101,14 @@ GroupExprBindingIterator::GroupExprBindingIterator( LOG_TRACE( "Attempting to bind on group %d with 
expression of type %s, children " "size %lu", - gexpr->GetGroupID(), gexpr->Op()->GetName().c_str(), child_groups.size()); + gexpr->GetGroupID(), gexpr->Node()->GetName().c_str(), child_groups.size()); // Find all bindings for children children_bindings_.resize(child_groups.size(), {}); children_bindings_pos_.resize(child_groups.size(), 0); for (size_t i = 0; i < child_groups.size(); ++i) { // Try to find a match in the given group - std::vector> &child_bindings = + std::vector> &child_bindings = children_bindings_[i]; GroupBindingIterator iterator(memo_, child_groups[i], child_patterns[i]); @@ -138,7 +138,7 @@ bool GroupExprBindingIterator::HasNext() { // The first child to be modified int first_modified_idx = children_bindings_pos_.size() - 1; for (; first_modified_idx >= 0; --first_modified_idx) { - const std::vector> &child_binding = + const std::vector> &child_binding = children_bindings_[first_modified_idx]; // Try to increment idx from the back @@ -162,9 +162,9 @@ bool GroupExprBindingIterator::HasNext() { // Add new children to end for (size_t offset = first_modified_idx; offset < children_bindings_pos_.size(); ++offset) { - const std::vector> &child_binding = + const std::vector> &child_binding = children_bindings_[offset]; - std::shared_ptr binding = + std::shared_ptr binding = child_binding[children_bindings_pos_[offset]]; current_binding_->PushChild(binding); } @@ -173,7 +173,7 @@ bool GroupExprBindingIterator::HasNext() { return has_next_; } -std::shared_ptr GroupExprBindingIterator::Next() { +std::shared_ptr GroupExprBindingIterator::Next() { return current_binding_; } diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index 9b8adebbb68..c41562e07a7 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -38,7 +38,7 @@ ChildPropertyDeriver::GetProperties(GroupExpression *gexpr, output_.clear(); memo_ = memo; gexpr_ = gexpr; - gexpr->Op()->Accept(this); + 
gexpr->Node()->Accept(this); return move(output_); } diff --git a/src/optimizer/group.cpp b/src/optimizer/group.cpp index 33bc32484a0..e3e00eba120 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -31,7 +31,7 @@ void Group::AddExpression(std::shared_ptr expr, expr->SetGroupID(id_); if (enforced) enforced_exprs_.push_back(expr); - else if (expr->Op()->IsPhysical()) + else if (expr->Node()->IsPhysical()) physical_expressions_.push_back(expr); else logical_expressions_.push_back(expr); @@ -40,7 +40,7 @@ void Group::AddExpression(std::shared_ptr expr, bool Group::SetExpressionCost(GroupExpression *expr, double cost, std::shared_ptr &properties) { LOG_TRACE("Adding expression cost on group %d with op %s, req %s", - expr->GetGroupID(), expr->Op()->GetName().c_str(), + expr->GetGroupID(), expr->Node()->GetName().c_str(), properties->ToString().c_str()); auto it = lowest_cost_expressions_.find(properties); if (it == lowest_cost_expressions_.end() || std::get<0>(it->second) > cost) { @@ -86,7 +86,7 @@ const std::string Group::GetInfo(int num_indent) const { for (auto expr : logical_expressions_) { os << StringUtil::Indent(num_indent + 4) - << expr->Op()->GetName() << std::endl; + << expr->Node()->GetName() << std::endl; const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); if (ChildGroupIDs.size() > 0) { os << StringUtil::Indent(num_indent + 6) @@ -102,7 +102,7 @@ const std::string Group::GetInfo(int num_indent) const { << "physical_expressions_: \n"; for (auto expr : physical_expressions_) { os << StringUtil::Indent(num_indent + 4) - << expr->Op()->GetName() << std::endl; + << expr->Node()->GetName() << std::endl; const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); if (ChildGroupIDs.size() > 0) { os << StringUtil::Indent(num_indent + 6) @@ -119,7 +119,7 @@ const std::string Group::GetInfo(int num_indent) const { << "enforced_exprs_: \n"; for (auto expr : enforced_exprs_) { os << StringUtil::Indent(num_indent + 4) - << expr->Op()->GetName() 
<< std::endl; + << expr->Node()->GetName() << std::endl; const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); if (ChildGroupIDs.size() > 0) { os << StringUtil::Indent(num_indent + 6) diff --git a/src/optimizer/group_expression.cpp b/src/optimizer/group_expression.cpp index 5db0fb32a82..ae160ded2a1 100644 --- a/src/optimizer/group_expression.cpp +++ b/src/optimizer/group_expression.cpp @@ -43,7 +43,7 @@ GroupID GroupExpression::GetChildGroupId(int child_idx) const { return child_groups[child_idx]; } -std::shared_ptr GroupExpression::Op() const { +std::shared_ptr GroupExpression::Node() const { return std::shared_ptr(node); } @@ -87,7 +87,7 @@ hash_t GroupExpression::Hash() const { } bool GroupExpression::operator==(const GroupExpression &r) { - return (*node == *r.Op()) && (child_groups == r.child_groups); + return (*node == *r.Node()) && (child_groups == r.child_groups); } void GroupExpression::SetRuleExplored(Rule *rule) { diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index 7104485454e..c2f94cc2d3d 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -43,7 +43,7 @@ InputColumnDeriver::DeriveInputColumns( gexpr_ = gexpr; required_cols_ = move(required_cols); memo_ = memo; - gexpr->Op()->Accept(this); + gexpr->Node()->Accept(this); return move(output_input_cols_); } diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index 69e1e8a54f4..db50e2f2936 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -31,8 +31,8 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, GroupID target_group, bool enforced) { // If leaf, then just return - if (gexpr->Op()->GetOpType() == OpType::Leaf) { - const LeafOperator *leaf = gexpr->Op()->As(); + if (gexpr->Node()->GetOpType() == OpType::Leaf) { + const LeafOperator *leaf = gexpr->Node()->As(); PELOTON_ASSERT(target_group == 
UNDEFINED_GROUP || target_group == leaf->origin_group); gexpr->SetGroupID(leaf->origin_group); @@ -91,14 +91,14 @@ GroupID Memo::AddNewGroup(std::shared_ptr gexpr) { GroupID new_group_id = groups_.size(); // Find out the table alias that this group represents std::unordered_set table_aliases; - auto op_type = gexpr->Op()->GetOpType(); + auto op_type = gexpr->Node()->GetOpType(); if (op_type == OpType::Get) { // For base group, the table alias can get directly from logical get - const LogicalGet *logical_get = gexpr->Op()->As(); + const LogicalGet *logical_get = gexpr->Node()->As(); table_aliases.insert(logical_get->table_alias); } else if (op_type == OpType::LogicalQueryDerivedGet) { const LogicalQueryDerivedGet *query_get = - gexpr->Op()->As(); + gexpr->Node()->As(); table_aliases.insert(query_get->table_alias); } else { // For other groups, need to aggregate the table alias from children diff --git a/src/optimizer/operator_expression.cpp b/src/optimizer/operator_expression.cpp index 9ba9d2de706..ba96a124e27 100644 --- a/src/optimizer/operator_expression.cpp +++ b/src/optimizer/operator_expression.cpp @@ -2,11 +2,11 @@ // // Peloton // -// op_expression.cpp +// operator_expression.cpp // -// Identification: src/optimizer/op_expression.cpp +// Identification: src/optimizer/operator_expression.cpp // -// Copyright (c) 2015-16, Carnegie Mellon University Database Group +// Copyright (c) 2015-19, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// @@ -20,26 +20,26 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Operator Expression //===--------------------------------------------------------------------===// -OperatorExpression::OperatorExpression(Operator op) : op(op) {} +OperatorExpression::OperatorExpression(std::shared_ptr node) : node(node) {} -void OperatorExpression::PushChild(std::shared_ptr op) { - children.push_back(op); +void 
OperatorExpression::PushChild(std::shared_ptr node) { + children.push_back(node); } void OperatorExpression::PopChild() { children.pop_back(); } -const std::vector> +const std::vector> &OperatorExpression::Children() const { return children; } -const Operator &OperatorExpression::Op() const { return op; } +const std::shared_ptr OperatorExpression::Node() const { return node; } const std::string OperatorExpression::GetInfo() const { std::string info = "{"; { info += "\"Op\":"; - info += "\"" + op.GetName() + "\","; + info += "\"" + node->GetName() + "\","; if (!children.empty()) { info += "\"Children\":["; { diff --git a/src/optimizer/operator_node.cpp b/src/optimizer/operator_node.cpp index 38e03310b94..cb80ab5bf39 100644 --- a/src/optimizer/operator_node.cpp +++ b/src/optimizer/operator_node.cpp @@ -21,7 +21,7 @@ namespace optimizer { //===--------------------------------------------------------------------===// Operator::Operator() : AbstractNode(nullptr) {} -Operator::Operator(AbstractNode *node) : AbstractNode(node) {} +Operator::Operator(std::shared_ptr node) : AbstractNode(node) {} void Operator::Accept(OperatorVisitor *v) const { node->Accept(v); } diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index f75b1159a42..17038d1b1eb 100644 --- a/src/optimizer/operators.cpp +++ b/src/optimizer/operators.cpp @@ -21,16 +21,16 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Leaf //===--------------------------------------------------------------------===// -Operator LeafOperator::make(GroupID group) { +std::shared_ptr LeafOperator::make(GroupID group) { LeafOperator *op = new LeafOperator; op->origin_group = group; - return Operator(op); + return std::make_shared(std::shared_ptr(op)); } //===--------------------------------------------------------------------===// // Get //===--------------------------------------------------------------------===// -Operator LogicalGet::make(oid_t get_id, 
+std::shared_ptr LogicalGet::make(oid_t get_id, std::vector predicates, std::shared_ptr table, std::string alias, bool update) { @@ -41,7 +41,7 @@ Operator LogicalGet::make(oid_t get_id, get->is_for_update = update; get->get_id = get_id; util::to_lower_string(get->table_alias); - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } hash_t LogicalGet::Hash() const { @@ -67,9 +67,9 @@ bool LogicalGet::operator==(const AbstractNode &r) { // External file get //===--------------------------------------------------------------------===// -Operator LogicalExternalFileGet::make(oid_t get_id, ExternalFileFormat format, - std::string file_name, char delimiter, - char quote, char escape) { +std::shared_ptr LogicalExternalFileGet::make( + oid_t get_id, ExternalFileFormat format, std::string file_name, char delimiter, + char quote, char escape) { auto *get = new LogicalExternalFileGet(); get->get_id = get_id; get->format = format; @@ -77,7 +77,7 @@ Operator LogicalExternalFileGet::make(oid_t get_id, ExternalFileFormat format, get->delimiter = delimiter; get->quote = quote; get->escape = escape; - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } bool LogicalExternalFileGet::operator==(const AbstractNode &node) { @@ -103,7 +103,7 @@ hash_t LogicalExternalFileGet::Hash() const { //===--------------------------------------------------------------------===// // Query derived get //===--------------------------------------------------------------------===// -Operator LogicalQueryDerivedGet::make( +std::shared_ptr LogicalQueryDerivedGet::make( oid_t get_id, std::string &alias, std::unordered_map> @@ -113,7 +113,7 @@ Operator LogicalQueryDerivedGet::make( get->alias_to_expr_map = alias_to_expr_map; get->get_id = get_id; - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } bool LogicalQueryDerivedGet::operator==(const AbstractNode &node) { @@ -132,10 +132,10 @@ hash_t LogicalQueryDerivedGet::Hash() const { 
//===--------------------------------------------------------------------===// // Select //===--------------------------------------------------------------------===// -Operator LogicalFilter::make(std::vector &filter) { +std::shared_ptr LogicalFilter::make(std::vector &filter) { LogicalFilter *select = new LogicalFilter; select->predicates = std::move(filter); - return Operator(select); + return std::make_shared(std::shared_ptr(select)); } hash_t LogicalFilter::Hash() const { @@ -158,27 +158,27 @@ bool LogicalFilter::operator==(const AbstractNode &r) { //===--------------------------------------------------------------------===// // Project //===--------------------------------------------------------------------===// -Operator LogicalProjection::make( +std::shared_ptr LogicalProjection::make( std::vector> &elements) { LogicalProjection *projection = new LogicalProjection; projection->expressions = std::move(elements); - return Operator(projection); + return std::make_shared(std::shared_ptr(projection)); } //===--------------------------------------------------------------------===// // DependentJoin //===--------------------------------------------------------------------===// -Operator LogicalDependentJoin::make() { +std::shared_ptr LogicalDependentJoin::make() { LogicalDependentJoin *join = new LogicalDependentJoin; join->join_predicates = {}; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } -Operator LogicalDependentJoin::make( +std::shared_ptr LogicalDependentJoin::make( std::vector &conditions) { LogicalDependentJoin *join = new LogicalDependentJoin; join->join_predicates = std::move(conditions); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t LogicalDependentJoin::Hash() const { @@ -204,16 +204,16 @@ bool LogicalDependentJoin::operator==(const AbstractNode &r) { //===--------------------------------------------------------------------===// // MarkJoin 
//===--------------------------------------------------------------------===// -Operator LogicalMarkJoin::make() { +std::shared_ptr LogicalMarkJoin::make() { LogicalMarkJoin *join = new LogicalMarkJoin; join->join_predicates = {}; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } -Operator LogicalMarkJoin::make(std::vector &conditions) { +std::shared_ptr LogicalMarkJoin::make(std::vector &conditions) { LogicalMarkJoin *join = new LogicalMarkJoin; join->join_predicates = std::move(conditions); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t LogicalMarkJoin::Hash() const { @@ -238,16 +238,16 @@ bool LogicalMarkJoin::operator==(const AbstractNode &r) { //===--------------------------------------------------------------------===// // SingleJoin //===--------------------------------------------------------------------===// -Operator LogicalSingleJoin::make() { +std::shared_ptr LogicalSingleJoin::make() { LogicalMarkJoin *join = new LogicalMarkJoin; join->join_predicates = {}; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } -Operator LogicalSingleJoin::make(std::vector &conditions) { +std::shared_ptr LogicalSingleJoin::make(std::vector &conditions) { LogicalSingleJoin *join = new LogicalSingleJoin; join->join_predicates = std::move(conditions); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t LogicalSingleJoin::Hash() const { @@ -272,16 +272,16 @@ bool LogicalSingleJoin::operator==(const AbstractNode &r) { //===--------------------------------------------------------------------===// // InnerJoin //===--------------------------------------------------------------------===// -Operator LogicalInnerJoin::make() { +std::shared_ptr LogicalInnerJoin::make() { LogicalInnerJoin *join = new LogicalInnerJoin; join->join_predicates = {}; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } -Operator LogicalInnerJoin::make(std::vector 
&conditions) { +std::shared_ptr LogicalInnerJoin::make(std::vector &conditions) { LogicalInnerJoin *join = new LogicalInnerJoin; join->join_predicates = std::move(conditions); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t LogicalInnerJoin::Hash() const { @@ -306,66 +306,66 @@ bool LogicalInnerJoin::operator==(const AbstractNode &r) { //===--------------------------------------------------------------------===// // LeftJoin //===--------------------------------------------------------------------===// -Operator LogicalLeftJoin::make(expression::AbstractExpression *condition) { +std::shared_ptr LogicalLeftJoin::make(expression::AbstractExpression *condition) { LogicalLeftJoin *join = new LogicalLeftJoin; join->join_predicate = std::shared_ptr(condition); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // RightJoin //===--------------------------------------------------------------------===// -Operator LogicalRightJoin::make(expression::AbstractExpression *condition) { +std::shared_ptr LogicalRightJoin::make(expression::AbstractExpression *condition) { LogicalRightJoin *join = new LogicalRightJoin; join->join_predicate = std::shared_ptr(condition); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // OuterJoin //===--------------------------------------------------------------------===// -Operator LogicalOuterJoin::make(expression::AbstractExpression *condition) { +std::shared_ptr LogicalOuterJoin::make(expression::AbstractExpression *condition) { LogicalOuterJoin *join = new LogicalOuterJoin; join->join_predicate = std::shared_ptr(condition); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // OuterJoin 
//===--------------------------------------------------------------------===// -Operator LogicalSemiJoin::make(expression::AbstractExpression *condition) { +std::shared_ptr LogicalSemiJoin::make(expression::AbstractExpression *condition) { LogicalSemiJoin *join = new LogicalSemiJoin; join->join_predicate = std::shared_ptr(condition); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // Aggregate //===--------------------------------------------------------------------===// -Operator LogicalAggregateAndGroupBy::make() { +std::shared_ptr LogicalAggregateAndGroupBy::make() { LogicalAggregateAndGroupBy *group_by = new LogicalAggregateAndGroupBy; group_by->columns = {}; - return Operator(group_by); + return std::make_shared(std::shared_ptr(group_by)); } -Operator LogicalAggregateAndGroupBy::make( +std::shared_ptr LogicalAggregateAndGroupBy::make( std::vector> &columns) { LogicalAggregateAndGroupBy *group_by = new LogicalAggregateAndGroupBy; group_by->columns = move(columns); - return Operator(group_by); + return std::make_shared(std::shared_ptr(group_by)); } -Operator LogicalAggregateAndGroupBy::make( +std::shared_ptr LogicalAggregateAndGroupBy::make( std::vector> &columns, std::vector &having) { LogicalAggregateAndGroupBy *group_by = new LogicalAggregateAndGroupBy; group_by->columns = move(columns); group_by->having = move(having); - return Operator(group_by); + return std::make_shared(std::shared_ptr(group_by)); } bool LogicalAggregateAndGroupBy::operator==(const AbstractNode &node) { @@ -390,7 +390,7 @@ hash_t LogicalAggregateAndGroupBy::Hash() const { //===--------------------------------------------------------------------===// // Insert //===--------------------------------------------------------------------===// -Operator LogicalInsert::make( +std::shared_ptr LogicalInsert::make( std::shared_ptr target_table, const std::vector *columns, const 
std::vectortarget_table = target_table; insert_op->columns = columns; insert_op->values = values; - return Operator(insert_op); + return std::make_shared(std::shared_ptr(insert_op)); } -Operator LogicalInsertSelect::make( +std::shared_ptr LogicalInsertSelect::make( std::shared_ptr target_table) { LogicalInsertSelect *insert_op = new LogicalInsertSelect; insert_op->target_table = target_table; - return Operator(insert_op); + return std::make_shared(std::shared_ptr(insert_op)); } //===--------------------------------------------------------------------===// // Delete //===--------------------------------------------------------------------===// -Operator LogicalDelete::make( +std::shared_ptr LogicalDelete::make( std::shared_ptr target_table) { LogicalDelete *delete_op = new LogicalDelete; delete_op->target_table = target_table; - return Operator(delete_op); + return std::make_shared(std::shared_ptr(delete_op)); } //===--------------------------------------------------------------------===// // Update //===--------------------------------------------------------------------===// -Operator LogicalUpdate::make( +std::shared_ptr LogicalUpdate::make( std::shared_ptr target_table, const std::vector> * updates) { LogicalUpdate *update_op = new LogicalUpdate; update_op->target_table = target_table; update_op->updates = updates; - return Operator(update_op); + return std::make_shared(std::shared_ptr(update_op)); } //===--------------------------------------------------------------------===// // Distinct //===--------------------------------------------------------------------===// -Operator LogicalDistinct::make() { - LogicalDistinct *distinct = new LogicalDistinct; - return Operator(distinct); +std::shared_ptr LogicalDistinct::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// // Limit //===--------------------------------------------------------------------===// -Operator LogicalLimit::make( 
+std::shared_ptr LogicalLimit::make( int64_t offset, int64_t limit, std::vector &&sort_exprs, std::vector &&sort_ascending) { @@ -452,13 +451,13 @@ Operator LogicalLimit::make( limit_op->limit = limit; limit_op->sort_exprs = std::move(sort_exprs); limit_op->sort_ascending = std::move(sort_ascending); - return Operator(limit_op); + return std::make_shared(std::shared_ptr(limit_op)); } //===--------------------------------------------------------------------===// // External file output //===--------------------------------------------------------------------===// -Operator LogicalExportExternalFile::make(ExternalFileFormat format, +std::shared_ptr LogicalExportExternalFile::make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape) { auto *export_op = new LogicalExportExternalFile(); @@ -467,7 +466,7 @@ Operator LogicalExportExternalFile::make(ExternalFileFormat format, export_op->delimiter = delimiter; export_op->quote = quote; export_op->escape = escape; - return Operator(export_op); + return std::make_shared(std::shared_ptr(export_op)); } bool LogicalExportExternalFile::operator==(const AbstractNode &node) { @@ -493,15 +492,14 @@ hash_t LogicalExportExternalFile::Hash() const { //===--------------------------------------------------------------------===// // DummyScan //===--------------------------------------------------------------------===// -Operator DummyScan::make() { - DummyScan *dummy = new DummyScan; - return Operator(dummy); +std::shared_ptr DummyScan::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// // SeqScan //===--------------------------------------------------------------------===// -Operator PhysicalSeqScan::make( +std::shared_ptr PhysicalSeqScan::make( oid_t get_id, std::shared_ptr table, std::string alias, std::vector predicates, bool update) { @@ -513,7 +511,7 @@ Operator PhysicalSeqScan::make( scan->is_for_update = update; scan->get_id 
= get_id; - return Operator(scan); + return std::make_shared(std::shared_ptr(scan)); } bool PhysicalSeqScan::operator==(const AbstractNode &r) { @@ -538,7 +536,7 @@ hash_t PhysicalSeqScan::Hash() const { //===--------------------------------------------------------------------===// // IndexScan //===--------------------------------------------------------------------===// -Operator PhysicalIndexScan::make( +std::shared_ptr PhysicalIndexScan::make( oid_t get_id, std::shared_ptr table, std::string alias, std::vector predicates, bool update, oid_t index_id, std::vector key_column_id_list, @@ -556,7 +554,7 @@ Operator PhysicalIndexScan::make( scan->expr_type_list = std::move(expr_type_list); scan->value_list = std::move(value_list); - return Operator(scan); + return std::make_shared(std::shared_ptr(scan)); } bool PhysicalIndexScan::operator==(const AbstractNode &r) { @@ -588,7 +586,7 @@ hash_t PhysicalIndexScan::Hash() const { //===--------------------------------------------------------------------===// // Physical external file scan //===--------------------------------------------------------------------===// -Operator ExternalFileScan::make(oid_t get_id, ExternalFileFormat format, +std::shared_ptr ExternalFileScan::make(oid_t get_id, ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape) { auto *get = new ExternalFileScan(); @@ -598,7 +596,7 @@ Operator ExternalFileScan::make(oid_t get_id, ExternalFileFormat format, get->delimiter = delimiter; get->quote = quote; get->escape = escape; - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } bool ExternalFileScan::operator==(const AbstractNode &node) { @@ -624,7 +622,7 @@ hash_t ExternalFileScan::Hash() const { //===--------------------------------------------------------------------===// // Query derived get //===--------------------------------------------------------------------===// -Operator QueryDerivedScan::make( +std::shared_ptr QueryDerivedScan::make( 
oid_t get_id, std::string alias, std::unordered_map> @@ -634,7 +632,7 @@ Operator QueryDerivedScan::make( get->alias_to_expr_map = alias_to_expr_map; get->get_id = get_id; - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } bool QueryDerivedScan::operator==(const AbstractNode &node) { @@ -652,16 +650,14 @@ hash_t QueryDerivedScan::Hash() const { //===--------------------------------------------------------------------===// // OrderBy //===--------------------------------------------------------------------===// -Operator PhysicalOrderBy::make() { - PhysicalOrderBy *order_by = new PhysicalOrderBy; - - return Operator(order_by); +std::shared_ptr PhysicalOrderBy::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// // PhysicalLimit //===--------------------------------------------------------------------===// -Operator PhysicalLimit::make( +std::shared_ptr PhysicalLimit::make( int64_t offset, int64_t limit, std::vector sort_exprs, std::vector sort_ascending) { @@ -670,13 +666,13 @@ Operator PhysicalLimit::make( limit_op->limit = limit; limit_op->sort_exprs = sort_exprs; limit_op->sort_acsending = sort_ascending; - return Operator(limit_op); + return std::make_shared(std::shared_ptr(limit_op)); } //===--------------------------------------------------------------------===// // InnerNLJoin //===--------------------------------------------------------------------===// -Operator PhysicalInnerNLJoin::make( +std::shared_ptr PhysicalInnerNLJoin::make( std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { @@ -685,7 +681,7 @@ Operator PhysicalInnerNLJoin::make( join->left_keys = std::move(left_keys); join->right_keys = std::move(right_keys); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t PhysicalInnerNLJoin::Hash() const { @@ -724,37 +720,37 @@ bool PhysicalInnerNLJoin::operator==(const AbstractNode &r) { 
//===--------------------------------------------------------------------===// // LeftNLJoin //===--------------------------------------------------------------------===// -Operator PhysicalLeftNLJoin::make( +std::shared_ptr PhysicalLeftNLJoin::make( std::shared_ptr join_predicate) { PhysicalLeftNLJoin *join = new PhysicalLeftNLJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // RightNLJoin //===--------------------------------------------------------------------===// -Operator PhysicalRightNLJoin::make( +std::shared_ptr PhysicalRightNLJoin::make( std::shared_ptr join_predicate) { PhysicalRightNLJoin *join = new PhysicalRightNLJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // OuterNLJoin //===--------------------------------------------------------------------===// -Operator PhysicalOuterNLJoin::make( +std::shared_ptr PhysicalOuterNLJoin::make( std::shared_ptr join_predicate) { PhysicalOuterNLJoin *join = new PhysicalOuterNLJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // InnerHashJoin //===--------------------------------------------------------------------===// -Operator PhysicalInnerHashJoin::make( +std::shared_ptr PhysicalInnerHashJoin::make( std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { @@ -762,7 +758,7 @@ Operator PhysicalInnerHashJoin::make( join->join_predicates = std::move(conditions); join->left_keys = std::move(left_keys); join->right_keys = std::move(right_keys); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t 
PhysicalInnerHashJoin::Hash() const { @@ -801,37 +797,37 @@ bool PhysicalInnerHashJoin::operator==(const AbstractNode &r) { //===--------------------------------------------------------------------===// // LeftHashJoin //===--------------------------------------------------------------------===// -Operator PhysicalLeftHashJoin::make( +std::shared_ptr PhysicalLeftHashJoin::make( std::shared_ptr join_predicate) { PhysicalLeftHashJoin *join = new PhysicalLeftHashJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // RightHashJoin //===--------------------------------------------------------------------===// -Operator PhysicalRightHashJoin::make( +std::shared_ptr PhysicalRightHashJoin::make( std::shared_ptr join_predicate) { PhysicalRightHashJoin *join = new PhysicalRightHashJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // OuterHashJoin //===--------------------------------------------------------------------===// -Operator PhysicalOuterHashJoin::make( +std::shared_ptr PhysicalOuterHashJoin::make( std::shared_ptr join_predicate) { PhysicalOuterHashJoin *join = new PhysicalOuterHashJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // PhysicalInsert //===--------------------------------------------------------------------===// -Operator PhysicalInsert::make( +std::shared_ptr PhysicalInsert::make( std::shared_ptr target_table, const std::vector *columns, const std::vectortarget_table = target_table; insert_op->columns = columns; insert_op->values = values; - return Operator(insert_op); + return 
std::make_shared(std::shared_ptr(insert_op)); } //===--------------------------------------------------------------------===// // PhysicalInsertSelect //===--------------------------------------------------------------------===// -Operator PhysicalInsertSelect::make( +std::shared_ptr PhysicalInsertSelect::make( std::shared_ptr target_table) { PhysicalInsertSelect *insert_op = new PhysicalInsertSelect; insert_op->target_table = target_table; - return Operator(insert_op); + return std::make_shared(std::shared_ptr(insert_op)); } //===--------------------------------------------------------------------===// // PhysicalDelete //===--------------------------------------------------------------------===// -Operator PhysicalDelete::make( +std::shared_ptr PhysicalDelete::make( std::shared_ptr target_table) { PhysicalDelete *delete_op = new PhysicalDelete; delete_op->target_table = target_table; - return Operator(delete_op); + return std::make_shared(std::shared_ptr(delete_op)); } //===--------------------------------------------------------------------===// // PhysicalUpdate //===--------------------------------------------------------------------===// -Operator PhysicalUpdate::make( +std::shared_ptr PhysicalUpdate::make( std::shared_ptr target_table, const std::vector> * updates) { PhysicalUpdate *update = new PhysicalUpdate; update->target_table = target_table; update->updates = updates; - return Operator(update); + return std::make_shared(std::shared_ptr(update)); } //===--------------------------------------------------------------------===// // PhysicalExportExternalFile //===--------------------------------------------------------------------===// -Operator PhysicalExportExternalFile::make(ExternalFileFormat format, +std::shared_ptr PhysicalExportExternalFile::make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape) { auto *export_op = new PhysicalExportExternalFile(); @@ -888,7 +884,7 @@ Operator 
PhysicalExportExternalFile::make(ExternalFileFormat format, export_op->delimiter = delimiter; export_op->quote = quote; export_op->escape = escape; - return Operator(export_op); + return std::make_shared(std::shared_ptr(export_op)); } bool PhysicalExportExternalFile::operator==(const AbstractNode &node) { @@ -914,13 +910,13 @@ hash_t PhysicalExportExternalFile::Hash() const { //===--------------------------------------------------------------------===// // PhysicalHashGroupBy //===--------------------------------------------------------------------===// -Operator PhysicalHashGroupBy::make( +std::shared_ptr PhysicalHashGroupBy::make( std::vector> columns, std::vector having) { PhysicalHashGroupBy *agg = new PhysicalHashGroupBy; agg->columns = columns; agg->having = move(having); - return Operator(agg); + return std::make_shared(std::shared_ptr(agg)); } bool PhysicalHashGroupBy::operator==(const AbstractNode &node) { @@ -945,13 +941,13 @@ hash_t PhysicalHashGroupBy::Hash() const { //===--------------------------------------------------------------------===// // PhysicalSortGroupBy //===--------------------------------------------------------------------===// -Operator PhysicalSortGroupBy::make( +std::shared_ptr PhysicalSortGroupBy::make( std::vector> columns, std::vector having) { PhysicalSortGroupBy *agg = new PhysicalSortGroupBy; agg->columns = std::move(columns); agg->having = move(having); - return Operator(agg); + return std::make_shared(std::shared_ptr(agg)); } bool PhysicalSortGroupBy::operator==(const AbstractNode &node) { @@ -976,17 +972,15 @@ hash_t PhysicalSortGroupBy::Hash() const { //===--------------------------------------------------------------------===// // PhysicalAggregate //===--------------------------------------------------------------------===// -Operator PhysicalAggregate::make() { - PhysicalAggregate *agg = new PhysicalAggregate; - return Operator(agg); +std::shared_ptr PhysicalAggregate::make() { + return std::make_shared(); } 
//===--------------------------------------------------------------------===// // Physical Hash //===--------------------------------------------------------------------===// -Operator PhysicalDistinct::make() { - PhysicalDistinct *hash = new PhysicalDistinct; - return Operator(hash); +std::shared_ptr PhysicalDistinct::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index c12205c3c35..e7f648e9375 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -327,7 +327,7 @@ const std::string Optimizer::GetOperatorInfo( auto gexpr = group->GetBestExpression(required_props); os << std::endl << StringUtil::Indent(num_indent) << "operator name: " - << gexpr->Op()->GetName().c_str(); + << gexpr->Node()->GetName().c_str(); vector child_groups = gexpr->GetChildGroupIDs(); auto required_input_props = gexpr->GetInputProperties(required_props); @@ -383,9 +383,7 @@ unique_ptr Optimizer::ChooseBestPlan( } // Derive root plan - // TODO(ncx): fix once OperatorExpression is abstracted - shared_ptr op = - make_shared(*(Operator *)gexpr->Op().get()); + std::shared_ptr op(make_shared(gexpr->Node())); PlanGenerator generator; auto plan = generator.ConvertOpExpression(op, required_props, required_cols, diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index cd3f7e5c778..f52bff6c3ee 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -33,10 +33,10 @@ void OptimizerTask::ConstructValidRules( // Check if we can apply the rule // TODO(ncx): replace after pattern fix // bool root_pattern_mismatch = - // group_expr->Op()->GetOpType() != rule->GetMatchPattern()->OpType() - // || group_expr->Op()->GetExpType() != rule->GetMatchPattern()->ExpType(); + // group_expr->Node()->GetOpType() != rule->GetMatchPattern()->OpType() + // || group_expr->Node()->GetExpType() != 
rule->GetMatchPattern()->ExpType(); bool root_pattern_mismatch = - group_expr->Op()->GetOpType() != rule->GetMatchPattern()->Type(); + group_expr->Node()->GetOpType() != rule->GetMatchPattern()->Type(); bool already_explored = group_expr->HasRuleExplored(rule.get()); bool child_pattern_mismatch = group_expr->GetChildrenGroupsSize() != @@ -101,7 +101,7 @@ void OptimizeExpression::execute() { std::sort(valid_rules.begin(), valid_rules.end()); LOG_DEBUG("OptimizeExpression::execute() op %d, valid rules : %lu", - static_cast(group_expr_->Op()->GetOpType()), valid_rules.size()); + static_cast(group_expr_->Node()->GetOpType()), valid_rules.size()); // Apply rule for (auto &r : valid_rules) { PushTask(new ApplyRule(group_expr_, r.rule, context_)); @@ -184,14 +184,14 @@ void ApplyRule::execute() { continue; } - std::vector> after; + std::vector> after; rule_->Transform(before, after, context_.get()); for (auto &new_expr : after) { std::shared_ptr new_gexpr; if (context_->metadata->RecordTransformedExpression( new_expr, new_gexpr, group_expr_->GetGroupID())) { // A new group expression is generated - if (new_gexpr->Op()->IsLogical()) { + if (new_gexpr->Node()->IsLogical()) { // Derive stats for the *logical expression* PushTask(new DeriveStats(new_gexpr.get(), ExprSet{}, context_)); if (explore_only) { @@ -427,7 +427,7 @@ void TopDownRewrite::execute() { if (iterator.HasNext()) { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); - std::vector> after; + std::vector> after; r.rule->Transform(before, after, context_.get()); // Rewrite rule should provide at most 1 expression @@ -488,7 +488,7 @@ void BottomUpRewrite::execute() { if (iterator.HasNext()) { auto before = iterator.Next(); PELOTON_ASSERT(!iterator.HasNext()); - std::vector> after; + std::vector> after; r.rule->Transform(before, after, context_.get()); // Rewrite rule should provide at most 1 expression diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index 
c04c8025903..6c83121d03b 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -66,7 +66,7 @@ unique_ptr PlanGenerator::ConvertOpExpression( output_cols_ = move(output_cols); children_plans_ = move(children_plans); children_expr_map_ = move(children_expr_map); - op->Op().Accept(this); + op->Node()->Accept(this); BuildProjectionPlan(); output_plan_->SetCardinality(estimated_cardinality); return move(output_plan_); diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 4357d81ae2f..882be0a2687 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -22,7 +22,7 @@ int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { // auto root_type = match_pattern->OpType(); auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } if (IsPhysical()) return PHYS_PROMISE; diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 3138492f7d4..24a3b541164 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -41,7 +41,7 @@ InnerJoinCommutativity::InnerJoinCommutativity() { match_pattern->AddChild(right_child); } -bool InnerJoinCommutativity::Check(std::shared_ptr expr, +bool InnerJoinCommutativity::Check(std::shared_ptr expr, OptimizeContext *context) const { (void)context; (void)expr; @@ -49,19 +49,21 @@ bool InnerJoinCommutativity::Check(std::shared_ptr expr, } void InnerJoinCommutativity::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - auto join_op = input->Op().As(); + auto join_op = input->Node()->As(); auto join_predicates = std::vector(join_op->join_predicates); + + auto result_plan = std::make_shared( 
LogicalInnerJoin::make(join_predicates)); - std::vector> children = input->Children(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 2); LOG_TRACE( "Reorder left child with op %s and right child with op %s for inner join", - children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str()); + children[0]->Node()->GetName().c_str(), children[1]->Node()->GetName().c_str()); result_plan->PushChild(children[1]); result_plan->PushChild(children[0]); @@ -86,7 +88,7 @@ InnerJoinAssociativity::InnerJoinAssociativity() { } // TODO: As far as I know, theres nothing else that needs to be checked -bool InnerJoinAssociativity::Check(std::shared_ptr expr, +bool InnerJoinAssociativity::Check(std::shared_ptr expr, OptimizeContext *context) const { (void)context; (void)expr; @@ -94,29 +96,29 @@ bool InnerJoinAssociativity::Check(std::shared_ptr expr, } void InnerJoinAssociativity::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const { // NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN // right) Variables are named accordingly to above transformation - auto parent_join = input->Op().As(); - std::vector> children = input->Children(); + auto parent_join = input->Node()->As(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 2); - PELOTON_ASSERT(children[0]->Op().GetOpType() == OpType::InnerJoin); + PELOTON_ASSERT(children[0]->Node()->GetOpType() == OpType::InnerJoin); PELOTON_ASSERT(children[0]->Children().size() == 2); - auto child_join = children[0]->Op().As(); + auto child_join = children[0]->Node()->As(); auto left = children[0]->Children()[0]; auto middle = children[0]->Children()[1]; auto right = children[1]; LOG_DEBUG("Reordered join structured: (%s JOIN %s) JOIN %s", - left->Op().GetName().c_str(), middle->Op().GetName().c_str(), - right->Op().GetName().c_str()); + 
left->Node()->GetName().c_str(), middle->Node()->GetName().c_str(), + right->Node()->GetName().c_str()); // Get Alias sets auto &memo = context->metadata->memo; - auto middle_group_id = middle->Op().As()->origin_group; - auto right_group_id = right->Op().As()->origin_group; + auto middle_group_id = middle->Node()->As()->origin_group; + auto right_group_id = right->Node()->As()->origin_group; const auto &middle_group_aliases_set = memo.GetGroupByID(middle_group_id)->GetTableAliases(); @@ -154,14 +156,14 @@ void InnerJoinAssociativity::Transform( } // Construct new child join operator - std::shared_ptr new_child_join = + std::shared_ptr new_child_join = std::make_shared( LogicalInnerJoin::make(new_child_join_predicates)); new_child_join->PushChild(middle); new_child_join->PushChild(right); // Construct new parent join operator - std::shared_ptr new_parent_join = + std::shared_ptr new_parent_join = std::make_shared( LogicalInnerJoin::make(new_parent_join_predicates)); new_parent_join->PushChild(left); @@ -182,16 +184,16 @@ GetToDummyScan::GetToDummyScan() { match_pattern = std::make_shared(OpType::Get); } -bool GetToDummyScan::Check(std::shared_ptr plan, +bool GetToDummyScan::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; - const LogicalGet *get = plan->Op().As(); + const LogicalGet *get = plan->Node()->As(); return get->table == nullptr; } void GetToDummyScan::Transform( - UNUSED_ATTRIBUTE std::shared_ptr input, - std::vector> &transformed, + UNUSED_ATTRIBUTE std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { auto result_plan = std::make_shared(DummyScan::make()); @@ -206,24 +208,24 @@ GetToSeqScan::GetToSeqScan() { match_pattern = std::make_shared(OpType::Get); } -bool GetToSeqScan::Check(std::shared_ptr plan, +bool GetToSeqScan::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; - const LogicalGet *get = plan->Op().As(); + const LogicalGet *get = 
plan->Node()->As(); return get->table != nullptr; } void GetToSeqScan::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalGet *get = input->Op().As(); + const LogicalGet *get = input->Node()->As(); auto result_plan = std::make_shared( PhysicalSeqScan::make(get->get_id, get->table, get->table_alias, get->predicates, get->is_for_update)); - UNUSED_ATTRIBUTE std::vector> children = + UNUSED_ATTRIBUTE std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 0); @@ -238,12 +240,12 @@ GetToIndexScan::GetToIndexScan() { match_pattern = std::make_shared(OpType::Get); } -bool GetToIndexScan::Check(std::shared_ptr plan, +bool GetToIndexScan::Check(std::shared_ptr plan, OptimizeContext *context) const { // If there is a index for the table, return true, // else return false (void)context; - const LogicalGet *get = plan->Op().As(); + const LogicalGet *get = plan->Node()->As(); bool index_exist = false; if (get != nullptr && get->table != nullptr && !get->table->GetIndexCatalogEntries().empty()) { @@ -253,14 +255,14 @@ bool GetToIndexScan::Check(std::shared_ptr plan, } void GetToIndexScan::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - UNUSED_ATTRIBUTE std::vector> children = + UNUSED_ATTRIBUTE std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 0); - const LogicalGet *get = input->Op().As(); + const LogicalGet *get = input->Node()->As(); // Get sort columns if they are all base columns and all in asc order auto sort = context->required_prop->GetPropertyOfType(PropertyType::SORT); @@ -415,17 +417,17 @@ LogicalQueryDerivedGetToPhysical::LogicalQueryDerivedGetToPhysical() { } bool LogicalQueryDerivedGetToPhysical::Check( - std::shared_ptr expr, OptimizeContext *context) const { + 
std::shared_ptr expr, OptimizeContext *context) const { (void)context; (void)expr; return true; } void LogicalQueryDerivedGetToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalQueryDerivedGet *get = input->Op().As(); + const LogicalQueryDerivedGet *get = input->Node()->As(); auto result_plan = std::make_shared(QueryDerivedScan::make( @@ -443,16 +445,16 @@ LogicalExternalFileGetToPhysical::LogicalExternalFileGetToPhysical() { } bool LogicalExternalFileGetToPhysical::Check( - UNUSED_ATTRIBUTE std::shared_ptr plan, + UNUSED_ATTRIBUTE std::shared_ptr plan, UNUSED_ATTRIBUTE OptimizeContext *context) const { return true; } void LogicalExternalFileGetToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const auto *get = input->Op().As(); + const auto *get = input->Node()->As(); auto result_plan = std::make_shared( ExternalFileScan::make(get->get_id, get->format, get->file_name, @@ -472,7 +474,7 @@ LogicalDeleteToPhysical::LogicalDeleteToPhysical() { match_pattern->AddChild(child); } -bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, +bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; @@ -480,10 +482,10 @@ bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, } void LogicalDeleteToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalDelete *delete_op = input->Op().As(); + const LogicalDelete *delete_op = input->Node()->As(); auto result = std::make_shared( PhysicalDelete::make(delete_op->target_table)); PELOTON_ASSERT(input->Children().size() == 1); @@ -500,7 +502,7 @@ 
LogicalUpdateToPhysical::LogicalUpdateToPhysical() { match_pattern->AddChild(child); } -bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, +bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; @@ -508,10 +510,10 @@ bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, } void LogicalUpdateToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalUpdate *update_op = input->Op().As(); + const LogicalUpdate *update_op = input->Node()->As(); auto result = std::make_shared( PhysicalUpdate::make(update_op->target_table, update_op->updates)); PELOTON_ASSERT(input->Children().size() != 0); @@ -528,7 +530,7 @@ LogicalInsertToPhysical::LogicalInsertToPhysical() { // match_pattern->AddChild(child); } -bool LogicalInsertToPhysical::Check(std::shared_ptr plan, +bool LogicalInsertToPhysical::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; @@ -536,10 +538,11 @@ bool LogicalInsertToPhysical::Check(std::shared_ptr plan, } void LogicalInsertToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalInsert *insert_op = input->Op().As(); + const LogicalInsert *insert_op = input->Node()->As(); + std::cout << insert_op << " val" << std::endl; auto result = std::make_shared(PhysicalInsert::make( insert_op->target_table, insert_op->columns, insert_op->values)); PELOTON_ASSERT(input->Children().size() == 0); @@ -557,17 +560,17 @@ LogicalInsertSelectToPhysical::LogicalInsertSelectToPhysical() { } bool LogicalInsertSelectToPhysical::Check( - std::shared_ptr plan, OptimizeContext *context) const { + std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; return true; } void 
LogicalInsertSelectToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalInsertSelect *insert_op = input->Op().As(); + const LogicalInsertSelect *insert_op = input->Node()->As(); auto result = std::make_shared( PhysicalInsertSelect::make(insert_op->target_table)); PELOTON_ASSERT(input->Children().size() == 1); @@ -585,20 +588,20 @@ LogicalGroupByToHashGroupBy::LogicalGroupByToHashGroupBy() { } bool LogicalGroupByToHashGroupBy::Check( - UNUSED_ATTRIBUTE std::shared_ptr plan, + UNUSED_ATTRIBUTE std::shared_ptr plan, OptimizeContext *context) const { (void)context; const LogicalAggregateAndGroupBy *agg_op = - plan->Op().As(); + plan->Node()->As(); return !agg_op->columns.empty(); } void LogicalGroupByToHashGroupBy::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalAggregateAndGroupBy *agg_op = - input->Op().As(); + input->Node()->As(); auto result = std::make_shared( PhysicalHashGroupBy::make(agg_op->columns, agg_op->having)); PELOTON_ASSERT(input->Children().size() == 1); @@ -616,17 +619,17 @@ LogicalAggregateToPhysical::LogicalAggregateToPhysical() { } bool LogicalAggregateToPhysical::Check( - UNUSED_ATTRIBUTE std::shared_ptr plan, + UNUSED_ATTRIBUTE std::shared_ptr plan, OptimizeContext *context) const { (void)context; const LogicalAggregateAndGroupBy *agg_op = - plan->Op().As(); + plan->Node()->As(); return agg_op->columns.empty(); } void LogicalAggregateToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { auto result = std::make_shared(PhysicalAggregate::make()); PELOTON_ASSERT(input->Children().size() == 1); @@ -653,7 +656,7 @@ 
InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { return; } -bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, +bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -661,16 +664,16 @@ bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, } void InnerJoinToInnerNLJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join - const LogicalInnerJoin *inner_join = input->Op().As(); + const LogicalInnerJoin *inner_join = input->Node()->As(); auto children = input->Children(); PELOTON_ASSERT(children.size() == 2); - auto left_group_id = children[0]->Op().As()->origin_group; - auto right_group_id = children[1]->Op().As()->origin_group; + auto left_group_id = children[0]->Node()->As()->origin_group; + auto right_group_id = children[1]->Node()->As()->origin_group; auto &left_group_alias = context->metadata->memo.GetGroupByID(left_group_id)->GetTableAliases(); auto &right_group_alias = @@ -714,7 +717,7 @@ InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { return; } -bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, +bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -722,16 +725,16 @@ bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, } void InnerJoinToInnerHashJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join - const LogicalInnerJoin *inner_join = input->Op().As(); + const LogicalInnerJoin *inner_join = input->Node()->As(); auto children = input->Children(); PELOTON_ASSERT(children.size() == 2); - auto left_group_id = children[0]->Op().As()->origin_group; - auto 
right_group_id = children[1]->Op().As()->origin_group; + auto left_group_id = children[0]->Node()->As()->origin_group; + auto right_group_id = children[1]->Node()->As()->origin_group; auto &left_group_alias = context->metadata->memo.GetGroupByID(left_group_id)->GetTableAliases(); auto &right_group_alias = @@ -765,7 +768,7 @@ ImplementDistinct::ImplementDistinct() { match_pattern->AddChild(std::make_shared(OpType::Leaf)); } -bool ImplementDistinct::Check(std::shared_ptr plan, +bool ImplementDistinct::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -773,13 +776,13 @@ bool ImplementDistinct::Check(std::shared_ptr plan, } void ImplementDistinct::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const { (void)context; auto result_plan = std::make_shared(PhysicalDistinct::make()); - std::vector> children = input->Children(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 1); result_plan->PushChild(children[0]); @@ -796,7 +799,7 @@ ImplementLimit::ImplementLimit() { match_pattern->AddChild(std::make_shared(OpType::Leaf)); } -bool ImplementLimit::Check(std::shared_ptr plan, +bool ImplementLimit::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -804,16 +807,16 @@ bool ImplementLimit::Check(std::shared_ptr plan, } void ImplementLimit::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const { (void)context; - const LogicalLimit *limit_op = input->Op().As(); + const LogicalLimit *limit_op = input->Node()->As(); auto result_plan = std::make_shared( PhysicalLimit::make(limit_op->offset, limit_op->limit, limit_op->sort_exprs, limit_op->sort_ascending)); - std::vector> children = input->Children(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 
1); result_plan->PushChild(children[0]); @@ -830,23 +833,23 @@ LogicalExportToPhysicalExport::LogicalExportToPhysicalExport() { } bool LogicalExportToPhysicalExport::Check( - UNUSED_ATTRIBUTE std::shared_ptr plan, + UNUSED_ATTRIBUTE std::shared_ptr plan, UNUSED_ATTRIBUTE OptimizeContext *context) const { return true; } void LogicalExportToPhysicalExport::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const auto *export_op = input->Op().As(); + const auto *export_op = input->Node()->As(); auto result_plan = std::make_shared(PhysicalExportExternalFile::make( export_op->format, export_op->file_name, export_op->delimiter, export_op->quote, export_op->escape)); - std::vector> children = input->Children(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 1); result_plan->PushChild(children[0]); @@ -874,27 +877,27 @@ PushFilterThroughJoin::PushFilterThroughJoin() { match_pattern->AddChild(child); } -bool PushFilterThroughJoin::Check(std::shared_ptr, +bool PushFilterThroughJoin::Check(std::shared_ptr, OptimizeContext *) const { return true; } void PushFilterThroughJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PushFilterThroughJoin::Transform"); auto &memo = context->metadata->memo; auto join_op_expr = input->Children().at(0); auto &join_children = join_op_expr->Children(); - auto left_group_id = join_children[0]->Op().As()->origin_group; - auto right_group_id = join_children[1]->Op().As()->origin_group; + auto left_group_id = join_children[0]->Node()->As()->origin_group; + auto right_group_id = join_children[1]->Node()->As()->origin_group; const auto &left_group_aliases_set = memo.GetGroupByID(left_group_id)->GetTableAliases(); const auto &right_group_aliases_set = 
memo.GetGroupByID(right_group_id)->GetTableAliases(); - auto &predicates = input->Op().As()->predicates; + auto &predicates = input->Node()->As()->predicates; std::vector left_predicates; std::vector right_predicates; std::vector join_predicates; @@ -918,10 +921,10 @@ void PushFilterThroughJoin::Transform( // Construct join operator auto pre_join_predicate = - join_op_expr->Op().As()->join_predicates; + join_op_expr->Node()->As()->join_predicates; join_predicates.insert(join_predicates.end(), pre_join_predicate.begin(), pre_join_predicate.end()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalInnerJoin::make(join_predicates)); @@ -966,20 +969,20 @@ PushFilterThroughAggregation::PushFilterThroughAggregation() { match_pattern->AddChild(child); } -bool PushFilterThroughAggregation::Check(std::shared_ptr, +bool PushFilterThroughAggregation::Check(std::shared_ptr, OptimizeContext *) const { return true; } void PushFilterThroughAggregation::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PushFilterThroughAggregation::Transform"); auto aggregation_op = - input->Children().at(0)->Op().As(); + input->Children().at(0)->Node()->As(); - auto &predicates = input->Op().As()->predicates; + auto &predicates = input->Node()->As()->predicates; std::vector embedded_predicates; std::vector pushdown_predicates; @@ -1000,7 +1003,7 @@ void PushFilterThroughAggregation::Transform( embedded_predicates.emplace_back(predicate); } auto groupby_cols = aggregation_op->columns; - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalAggregateAndGroupBy::make(groupby_cols, embedded_predicates)); @@ -1030,7 +1033,7 @@ CombineConsecutiveFilter::CombineConsecutiveFilter() { match_pattern->AddChild(child); } -bool CombineConsecutiveFilter::Check(std::shared_ptr plan, +bool 
CombineConsecutiveFilter::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1046,18 +1049,18 @@ bool CombineConsecutiveFilter::Check(std::shared_ptr plan, } void CombineConsecutiveFilter::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { auto child_filter = input->Children()[0]; - auto root_predicates = input->Op().As()->predicates; - auto &child_predicates = child_filter->Op().As()->predicates; + auto root_predicates = input->Node()->As()->predicates; + auto &child_predicates = child_filter->Node()->As()->predicates; root_predicates.insert(root_predicates.end(), child_predicates.begin(), child_predicates.end()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalFilter::make(root_predicates)); @@ -1077,7 +1080,7 @@ EmbedFilterIntoGet::EmbedFilterIntoGet() { match_pattern->AddChild(child); } -bool EmbedFilterIntoGet::Check(std::shared_ptr plan, +bool EmbedFilterIntoGet::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1085,14 +1088,14 @@ bool EmbedFilterIntoGet::Check(std::shared_ptr plan, } void EmbedFilterIntoGet::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - auto get = input->Children()[0]->Op().As(); + auto get = input->Children()[0]->Node()->As(); - auto predicates = input->Op().As()->predicates; + auto predicates = input->Node()->As()->predicates; - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalGet::make(get->get_id, predicates, get->table, get->table_alias, get->is_for_update)); @@ -1115,13 +1118,13 @@ int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != 
OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } return static_cast(UnnestPromise::Low); } -bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, +bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1133,16 +1136,16 @@ bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, } void MarkJoinToInnerJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("MarkJoinToInnerJoin::Transform"); - UNUSED_ATTRIBUTE auto mark_join = input->Op().As(); + UNUSED_ATTRIBUTE auto mark_join = input->Node()->As(); auto &join_children = input->Children(); PELOTON_ASSERT(mark_join->join_predicates.empty()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared(LogicalInnerJoin::make()); output->PushChild(join_children[0]); @@ -1166,13 +1169,13 @@ int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } return static_cast(UnnestPromise::Low); } -bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, +bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1184,16 +1187,16 @@ bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, } void SingleJoinToInnerJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("SingleJoinToInnerJoin::Transform"); - UNUSED_ATTRIBUTE auto single_join = input->Op().As(); + UNUSED_ATTRIBUTE 
auto single_join = input->Node()->As(); auto &join_children = input->Children(); PELOTON_ASSERT(single_join->join_predicates.empty()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared(LogicalInnerJoin::make()); output->PushChild(join_children[0]); @@ -1219,13 +1222,13 @@ int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } return static_cast(UnnestPromise::High); } -bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, +bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1239,22 +1242,22 @@ bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, } void PullFilterThroughMarkJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PullFilterThroughMarkJoin::Transform"); - UNUSED_ATTRIBUTE auto mark_join = input->Op().As(); + UNUSED_ATTRIBUTE auto mark_join = input->Node()->As(); auto &join_children = input->Children(); - auto filter = join_children[1]->Op(); + auto filter = join_children[1]->Node(); auto &filter_children = join_children[1]->Children(); PELOTON_ASSERT(mark_join->join_predicates.empty()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared(filter); - std::shared_ptr join = - std::make_shared(input->Op()); + std::shared_ptr join = + std::make_shared(input->Node()); join->PushChild(join_children[0]); join->PushChild(filter_children[0]); @@ -1280,14 +1283,14 @@ int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, (void)context; auto root_type = match_pattern->Type(); // This rule is not applicable - if (root_type != 
OpType::Leaf && root_type != group_expr->Op()->GetOpType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } return static_cast(UnnestPromise::High); } bool PullFilterThroughAggregation::Check( - std::shared_ptr plan, OptimizeContext *context) const { + std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1300,18 +1303,18 @@ bool PullFilterThroughAggregation::Check( } void PullFilterThroughAggregation::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PullFilterThroughAggregation::Transform"); auto &memo = context->metadata->memo; auto &filter_expr = input->Children()[0]; auto child_group_id = - filter_expr->Children()[0]->Op().As()->origin_group; + filter_expr->Children()[0]->Node()->As()->origin_group; const auto &child_group_aliases_set = memo.GetGroupByID(child_group_id)->GetTableAliases(); - auto &predicates = filter_expr->Op().As()->predicates; + auto &predicates = filter_expr->Node()->As()->predicates; std::vector correlated_predicates; std::vector normal_predicates; @@ -1337,15 +1340,15 @@ void PullFilterThroughAggregation::Transform( return; } - auto aggregation = input->Op().As(); + auto aggregation = input->Node()->As(); for (auto &col : aggregation->columns) { new_groupby_cols.emplace_back(col->Copy()); } - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalFilter::make(correlated_predicates)); std::vector new_having(aggregation->having); - std::shared_ptr new_aggregation = + std::shared_ptr new_aggregation = std::make_shared( LogicalAggregateAndGroupBy::make(new_groupby_cols, new_having)); output->PushChild(new_aggregation); @@ -1353,7 +1356,7 @@ void PullFilterThroughAggregation::Transform( // Construct child filter if any if (!normal_predicates.empty()) { - std::shared_ptr new_filter = + std::shared_ptr new_filter = 
std::make_shared( LogicalFilter::make(normal_predicates)); new_aggregation->PushChild(new_filter); diff --git a/src/optimizer/stats/child_stats_deriver.cpp b/src/optimizer/stats/child_stats_deriver.cpp index 5831dfdaffb..2dc50a22e9d 100644 --- a/src/optimizer/stats/child_stats_deriver.cpp +++ b/src/optimizer/stats/child_stats_deriver.cpp @@ -27,7 +27,7 @@ vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, gexpr_ = gexpr; memo_ = memo; output_ = vector(gexpr->GetChildrenGroupsSize(), ExprSet{}); - gexpr->Op()->Accept(this); + gexpr->Node()->Accept(this); return std::move(output_); } diff --git a/src/optimizer/stats/stats_calculator.cpp b/src/optimizer/stats/stats_calculator.cpp index b327aa2df8b..0e4e09a573b 100644 --- a/src/optimizer/stats/stats_calculator.cpp +++ b/src/optimizer/stats/stats_calculator.cpp @@ -33,7 +33,7 @@ void StatsCalculator::CalculateStats(GroupExpression *gexpr, memo_ = memo; required_cols_ = required_cols; txn_ = txn; - gexpr->Op()->Accept(this); + gexpr->Node()->Accept(this); } void StatsCalculator::Visit(const LogicalGet *op) { diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index 23f520596dc..d2b806ef91b 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ b/test/optimizer/optimizer_rule_test.cpp @@ -61,7 +61,7 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) { EXPECT_TRUE(rule.Check(join, nullptr)); - std::vector> outputs; + std::vector> outputs; rule.Transform(join, outputs, nullptr); EXPECT_EQ(outputs.size(), 1); @@ -139,17 +139,17 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); EXPECT_EQ(right_leaf, parent_join->Children()[1]); EXPECT_EQ(1, - parent_join->Op().As()->join_predicates.size()); + parent_join->Node()->As()->join_predicates.size()); EXPECT_EQ(1, parent_join->Children()[0] - ->Op() - .As() + ->Node() + ->As() ->join_predicates.size()); // Setup rule InnerJoinAssociativity rule; 
EXPECT_TRUE(rule.Check(parent_join, root_context)); - std::vector> outputs; + std::vector> outputs; rule.Transform(parent_join, outputs, root_context); EXPECT_EQ(1, outputs.size()); @@ -159,8 +159,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); - auto parent_join_op = output_join->Op().As(); - auto child_join_op = output_join->Children()[1]->Op().As(); + auto parent_join_op = output_join->Node()->As(); + auto child_join_op = output_join->Children()[1]->Node()->As(); EXPECT_EQ(2, parent_join_op->join_predicates.size()); EXPECT_EQ(0, child_join_op->join_predicates.size()); delete root_context; @@ -234,17 +234,17 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); EXPECT_EQ(right_leaf, parent_join->Children()[1]); EXPECT_EQ(2, - parent_join->Op().As()->join_predicates.size()); + parent_join->Node()->As()->join_predicates.size()); EXPECT_EQ(0, parent_join->Children()[0] - ->Op() - .As() + ->Node() + ->As() ->join_predicates.size()); // Setup rule InnerJoinAssociativity rule; EXPECT_TRUE(rule.Check(parent_join, root_context)); - std::vector> outputs; + std::vector> outputs; rule.Transform(parent_join, outputs, root_context); EXPECT_EQ(1, outputs.size()); @@ -254,8 +254,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); - auto parent_join_op = output_join->Op().As(); - auto child_join_op = output_join->Children()[1]->Op().As(); + auto parent_join_op = output_join->Node()->As(); + auto child_join_op = output_join->Children()[1]->Node()->As(); EXPECT_EQ(1, parent_join_op->join_predicates.size()); EXPECT_EQ(1, child_join_op->join_predicates.size()); delete root_context; diff --git a/test/optimizer/optimizer_test.cpp 
b/test/optimizer/optimizer_test.cpp index 242d62896dc..c1247baeed6 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -368,28 +368,28 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, group_expr->Op()->GetOpType()); - auto join_op = group_expr->Op()->As(); + EXPECT_EQ(OpType::InnerJoin, group_expr->Node()->GetOpType()); + auto join_op = group_expr->Node()->As(); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); // Check left get auto l_group_expr = GetSingleGroupExpression(memo, group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op()->GetOpType()); - auto get_op = l_group_expr->Op()->As(); + EXPECT_EQ(OpType::Get, l_group_expr->Node()->GetOpType()); + auto get_op = l_group_expr->Node()->As(); EXPECT_TRUE(get_op->predicates.empty()); // Check right filter auto r_group_expr = GetSingleGroupExpression(memo, group_expr, 1); - EXPECT_EQ(OpType::LogicalFilter, r_group_expr->Op()->GetOpType()); - auto filter_op = r_group_expr->Op()->As(); + EXPECT_EQ(OpType::LogicalFilter, r_group_expr->Node()->GetOpType()); + auto filter_op = r_group_expr->Node()->As(); EXPECT_EQ(1, filter_op->predicates.size()); EXPECT_TRUE(filter_op->predicates[0].expr->ExactlyEquals(*predicates[1])); // Check get below filter group_expr = GetSingleGroupExpression(memo, r_group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op()->GetOpType()); - get_op = group_expr->Op()->As(); + EXPECT_EQ(OpType::Get, l_group_expr->Node()->GetOpType()); + get_op = group_expr->Node()->As(); EXPECT_TRUE(get_op->predicates.empty()); txn_manager.CommitTransaction(txn); @@ -456,21 +456,21 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, 
group_expr->Op()->GetOpType()); - auto join_op = group_expr->Op()->As(); + EXPECT_EQ(OpType::InnerJoin, group_expr->Node()->GetOpType()); + auto join_op = group_expr->Node()->As(); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); // Check left get auto l_group_expr = GetSingleGroupExpression(memo, group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op()->GetOpType()); - auto get_op = l_group_expr->Op()->As(); + EXPECT_EQ(OpType::Get, l_group_expr->Node()->GetOpType()); + auto get_op = l_group_expr->Node()->As(); EXPECT_TRUE(get_op->predicates.empty()); // Check right filter auto r_group_expr = GetSingleGroupExpression(memo, group_expr, 1); - EXPECT_EQ(OpType::Get, r_group_expr->Op()->GetOpType()); - get_op = r_group_expr->Op()->As(); + EXPECT_EQ(OpType::Get, r_group_expr->Node()->GetOpType()); + get_op = r_group_expr->Node()->As(); EXPECT_EQ(1, get_op->predicates.size()); EXPECT_TRUE(get_op->predicates[0].expr->ExactlyEquals(*predicates[1])); From 5caae6d0b66d3295158dbdb84909920c04e220e8 Mon Sep 17 00:00:00 2001 From: William Zhang <17zhangw@gmail.com> Date: Sun, 5 May 2019 06:40:35 +0000 Subject: [PATCH 07/14] Ported equivalent transforms and other rules onto AbstractNodes --- src/include/common/internal_types.h | 36 +- src/include/expression/abstract_expression.h | 18 + .../expression/group_marker_expression.h | 64 +++ src/include/optimizer/absexpr_expression.h | 171 +++--- src/include/optimizer/abstract_node.h | 1 + src/include/optimizer/binding.h | 9 +- .../optimizer/child_property_deriver.h | 9 +- .../cost_model/abstract_cost_model.h | 5 +- src/include/optimizer/group.h | 26 +- src/include/optimizer/group_expression.h | 12 - src/include/optimizer/memo.h | 44 +- src/include/optimizer/optimize_context.h | 6 +- src/include/optimizer/optimizer.h | 19 +- src/include/optimizer/optimizer_metadata.h | 2 +- src/include/optimizer/optimizer_task.h | 148 +++-- 
src/include/optimizer/optimizer_task_pool.h | 16 +- src/include/optimizer/pattern.h | 15 +- src/include/optimizer/property_enforcer.h | 8 +- src/include/optimizer/rewriter.h | 21 +- src/include/optimizer/rule.h | 3 +- src/include/optimizer/rule_rewrite.h | 48 +- .../optimizer/stats/child_stats_deriver.h | 9 +- .../optimizer/stats/stats_calculator.h | 9 +- src/include/traffic_cop/traffic_cop.h | 3 + src/optimizer/absexpr_expression.cpp | 146 +++++ src/optimizer/binding.cpp | 39 +- src/optimizer/memo.cpp | 17 +- src/optimizer/optimizer_task.cpp | 222 ++++---- src/optimizer/pattern.cpp | 18 +- src/optimizer/rewriter.cpp | 129 +++-- src/optimizer/rule.cpp | 63 ++- src/optimizer/rule_impls.cpp | 8 +- src/optimizer/rule_rewrite.cpp | 404 ++++++++++++-- src/traffic_cop/traffic_cop.cpp | 28 +- test/include/optimizer/mock_task.h | 4 +- test/optimizer/absexpr_test.cpp | 460 ++++++++++++++++ test/optimizer/optimizer_test.cpp | 10 - test/optimizer/rewriter_test.cpp | 88 +-- test/optimizer/rule_rewrite_test.cpp | 521 ++++++++++++++++++ 39 files changed, 2191 insertions(+), 668 deletions(-) create mode 100644 src/include/expression/group_marker_expression.h create mode 100644 src/optimizer/absexpr_expression.cpp create mode 100644 test/optimizer/absexpr_test.cpp create mode 100644 test/optimizer/rule_rewrite_test.cpp diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 21de29a080e..39c9647b2ef 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1,15 +1,3 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// internal_types.h -// -// Identification: src/include/common/internal_types.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - //===----------------------------------------------------------------------===// // // Peloton @@ 
-258,7 +246,12 @@ enum class ExpressionType { // ----------------------------- // Miscellaneous // ----------------------------- - CAST = 600 + CAST = 600, + + // ----------------------------- + // Rewriter-specific identifier + // ----------------------------- + GROUP_MARKER = 721 }; // When short_str is true, return a short version of ExpressionType string @@ -1384,8 +1377,21 @@ enum class RuleType : uint32_t { PULL_FILTER_THROUGH_AGGREGATION, // AST rewrite rules (logical -> logical) - // Removes ConstantValueExpression = ConstantValueExpression - COMP_EQUALITY_ELIMINATION, + // Removes ConstantValue =/!=//<=/>= ConstantValue + CONSTANT_COMPARE_EQUAL, + CONSTANT_COMPARE_NOTEQUAL, + CONSTANT_COMPARE_LESSTHAN, + CONSTANT_COMPARE_GREATERTHAN, + CONSTANT_COMPARE_LESSTHANOREQUALTO, + CONSTANT_COMPARE_GREATERTHANOREQUALTO, + + // Logical equivalent + EQUIV_AND, + EQUIV_OR, + EQUIV_COMPARE_EQUAL, + + TV_EQUALITY_WITH_TWO_CV, // (A.B = x) AND (A.B = y) where x/y are constant + TRANSITIVE_CLOSURE_CONSTANT, // (A.B = x) AND (A.B = C.D) // Place holder to generate number of rules compile time NUM_RULES diff --git a/src/include/expression/abstract_expression.h b/src/include/expression/abstract_expression.h index 6acdf7b2751..5cfd3614cb8 100644 --- a/src/include/expression/abstract_expression.h +++ b/src/include/expression/abstract_expression.h @@ -111,6 +111,24 @@ class AbstractExpression : public Printable { children_[index].reset(expr); } + void ClearChildren() { + // The rewriter copies the AbstractExpression presented to the rewriter. + // This function is used by the rewriter to properly wipe all the children + // of AbstractExpression once the AbstractExpression tree has been converted + // to the intermediary AbsExpr_Container/Expression tree. + // + // This function should only be invoked on copied AbstractExpressions and + // never on original ones passed into the [rewriter] otherwise we would + // violate immutable properties. 
We do not believe this function is strictly + // necessary in terrier, however this function does serve some usefulness. + // + // This allows us to reduce our memory footprint while also providing better + // implementation constraints within the rewriter (i.e: the rewriter should + // not be trying to operate directly on AbstractExpression but rather on the + // intermediary representation). + children_.clear(); + } + void SetExpressionType(ExpressionType type) { exp_type_ = type; } ////////////////////////////////////////////////////////////////////////////// diff --git a/src/include/expression/group_marker_expression.h b/src/include/expression/group_marker_expression.h new file mode 100644 index 00000000000..25c717f16e5 --- /dev/null +++ b/src/include/expression/group_marker_expression.h @@ -0,0 +1,64 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// group_marker_expression.h +// +// Identification: src/include/expression/group_marker_expression.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "expression/abstract_expression.h" +#include "optimizer/group_expression.h" +#include "util/hash_util.h" + +namespace peloton { + +namespace executor { +class ExecutorContext; +} // namespace executor + +namespace expression { + +//===----------------------------------------------------------------------===// +// GroupMarkerExpression +//===----------------------------------------------------------------------===// + +class GroupMarkerExpression : public AbstractExpression { + public: + GroupMarkerExpression(optimizer::GroupID group_id) : + AbstractExpression(ExpressionType::GROUP_MARKER), + group_id_(group_id) {}; + + optimizer::GroupID GetGroupID() { return group_id_; } + + AbstractExpression *Copy() const override { + return new GroupMarkerExpression(group_id_); + } + + 
type::Value Evaluate(const AbstractTuple *tuple1, + const AbstractTuple *tuple2, + executor::ExecutorContext *context) const { + (void)tuple1; + (void)tuple2; + (void)context; + PELOTON_ASSERT(0); + } + + void Accept(SqlNodeVisitor *) { + PELOTON_ASSERT(0); + } + + protected: + optimizer::GroupID group_id_; + + GroupMarkerExpression(const GroupMarkerExpression &other) + : AbstractExpression(other), group_id_(other.group_id_) {} +}; + +} // namespace expression +} // namespace peloton diff --git a/src/include/optimizer/absexpr_expression.h b/src/include/optimizer/absexpr_expression.h index 745881ccfb0..d5c6f098be7 100644 --- a/src/include/optimizer/absexpr_expression.h +++ b/src/include/optimizer/absexpr_expression.h @@ -10,7 +10,8 @@ #pragma once -// AbstractExpression Definition +#include "optimizer/abstract_node_expression.h" +#include "optimizer/abstract_node.h" #include "expression/abstract_expression.h" #include "expression/conjunction_expression.h" #include "expression/comparison_expression.h" @@ -22,41 +23,45 @@ namespace peloton { namespace optimizer { -// (TODO): rethink the AbsExpr_Container/Expression approach in comparion to abstract -// Most of the core rule/optimizer code relies on the concept of an Operator / -// OperatorExpression and the interface that the two functions respectively expose. +// AbsExpr_Container and AbsExpr_Expression provides and serves an analogous purpose +// to Operator and OperatorExpression. Each AbsExpr_Container wraps a single +// AbstractExpression node with the children placed inside the AbsExpr_Expression. // -// The annoying part is that an AbstractExpression blends together an Operator -// and OperatorExpression. Second part, the AbstractExpression does not export the -// correct interface that the rest of the system depends on. -// -// As an extreme level of simplification (sort of hacky), an AbsExpr_Container is -// analogous to Operator and wraps a single AbstractExpression node. 
AbsExpr_Expression -// is analogous to OperatorExpression. -// -// AbsExpr_Container does *not* handle memory correctly w.r.t internal instantiations -// from Rule transformation. This is since Peloton itself mixes unique_ptrs and -// hands out raw pointers which makes adding a shared_ptr here extremely problematic. -// terrier uses only shared_ptr when dealing with AbstractExpression trees. - -class AbsExpr_Container { +// This is done to export the correct interface from the wrapped AbstractExpression +// to the rest of the core rule/optimizer code/logic. +class AbsExpr_Container: public AbstractNode { public: - AbsExpr_Container(); + // Default constructors + AbsExpr_Container() = default; + AbsExpr_Container(const AbsExpr_Container &other): + AbstractNode() { + expr = other.expr; + } - AbsExpr_Container(const expression::AbstractExpression *expr) { - node = expr; + AbsExpr_Container(std::shared_ptr expr_) { + expr = expr_; + } + + OpType GetOpType() const { + return OpType::Undefined; } // Return operator type - ExpressionType GetType() const { + ExpressionType GetExpType() const { if (IsDefined()) { - return node->GetExpressionType(); + return expr->GetExpressionType(); } return ExpressionType::INVALID; } - const expression::AbstractExpression *GetExpr() const { - return node; + std::shared_ptr GetExpr() const { + return expr; + } + + // Dummy Accept + void Accept(OperatorVisitor *v) const { + (void)v; + PELOTON_ASSERT(0); } // Operator contains Logical node @@ -71,7 +76,7 @@ class AbsExpr_Container { std::string GetName() const { if (IsDefined()) { - return node->GetExpressionName(); + return expr->GetExpressionName(); } return "Undefined"; @@ -79,30 +84,31 @@ class AbsExpr_Container { hash_t Hash() const { if (IsDefined()) { - return node->Hash(); + return expr->Hash(); } return 0; } + bool operator==(const AbstractNode &r) { + if (r.GetExpType() != ExpressionType::INVALID) { + const AbsExpr_Container &cnt = dynamic_cast(r); + return (*this == cnt); + } + + 
return false; + } + bool operator==(const AbsExpr_Container &r) { if (IsDefined() && r.IsDefined()) { - // (TODO): need a better way to determine deep equality - - // NOTE: - // Without proper equality determinations, the groups will - // not be assigned correctly. Arguably, terrier does this - // better because a blind ExactlyEquals on different types - // of ConstantValueExpression under Peloton will crash! - - // For now, just return (false). - // I don't anticipate this will affect correctness, just - // performance, since duplicate trees will have to evaluated - // over and over again, rather than being able to "borrow" - // a previous tree's rewrite. - // - // Probably not worth to create a "validator" since porting - // this to terrier anyways (?). == does not check Value - // so it's broken. ExactlyEqual requires precondition checking. + //TODO(): proper equality check when migrate to terrier + // Equality check relies on performing the following: + // - Check each node's ExpressionType + // - Check other parameters for a given node + // We believe that in terrier so long as the AbstractExpression + // are children-less, operator== provides sufficient checking. + // The reason behind why the children-less guarantee is required, + // is that the "real" children are actually tracked by the + // AbsExpr_Expression class. 
return false; } else if (!IsDefined() && !r.IsDefined()) { return true; @@ -112,52 +118,31 @@ class AbsExpr_Container { // Operator contains physical or logical operator node bool IsDefined() const { - return node != nullptr; - } - - //(TODO): fix memory management once go to terrier - expression::AbstractExpression *Rebuild(std::vector children) { - switch (GetType()) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_NOTEQUAL: - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - case ExpressionType::COMPARE_LIKE: - case ExpressionType::COMPARE_NOTLIKE: - case ExpressionType::COMPARE_IN: - case ExpressionType::COMPARE_DISTINCT_FROM: { - PELOTON_ASSERT(children.size() == 2); - return new expression::ComparisonExpression(GetType(), children[0], children[1]); - } - case ExpressionType::CONJUNCTION_AND: - case ExpressionType::CONJUNCTION_OR: { - PELOTON_ASSERT(children.size() == 2); - return new expression::ConjunctionExpression(GetType(), children[0], children[1]); - } - case ExpressionType::VALUE_CONSTANT: { - PELOTON_ASSERT(children.size() == 0); - auto cve = static_cast(node); - return new expression::ConstantValueExpression(cve->GetValue()); - } - default: { - int type = static_cast(GetType()); - LOG_ERROR("Unimplemented Rebuild() for %d found", type); - return nullptr; - } - } + return expr != nullptr; } + //TODO(): Function should use std::shared_ptr when migrate to terrier + expression::AbstractExpression *CopyWithChildren(std::vector children); + private: - const expression::AbstractExpression *node; + // Internal wrapped AbstractExpression + std::shared_ptr expr; }; -class AbsExpr_Expression { + +class AbsExpr_Expression: public AbstractNodeExpression { public: - AbsExpr_Expression(AbsExpr_Container op): op(op) {}; + AbsExpr_Expression(std::shared_ptr n) { + std::shared_ptr cnt = 
std::dynamic_pointer_cast(n); + PELOTON_ASSERT(cnt != nullptr); + + node = n; + } - void PushChild(std::shared_ptr op) { + // Disallow copy and move constructor + DISALLOW_COPY_AND_MOVE(AbsExpr_Expression); + + void PushChild(std::shared_ptr op) { children.push_back(op); } @@ -165,19 +150,27 @@ class AbsExpr_Expression { children.pop_back(); } - const std::vector> &Children() const { + const std::vector> &Children() const { return children; } - const AbsExpr_Container &Op() const { - return op; + const std::shared_ptr Node() const { + // Integrity constraint + std::shared_ptr cnt = std::dynamic_pointer_cast(node); + PELOTON_ASSERT(cnt != nullptr); + + return node; + } + + const std::string GetInfo() const { + //TODO(): create proper info statement? + return ""; } private: - AbsExpr_Container op; - std::vector> children; + std::shared_ptr node; + std::vector> children; }; } // namespace optimizer } // namespace peloton - diff --git a/src/include/optimizer/abstract_node.h b/src/include/optimizer/abstract_node.h index 2b5b5f40f4b..3b4a1cb16b1 100644 --- a/src/include/optimizer/abstract_node.h +++ b/src/include/optimizer/abstract_node.h @@ -83,6 +83,7 @@ enum class OpType { class OperatorVisitor; struct AbstractNode { + AbstractNode() {} AbstractNode(std::shared_ptr node) : node(node) {} ~AbstractNode() {} diff --git a/src/include/optimizer/binding.h b/src/include/optimizer/binding.h index d2097eaeff0..233e27f3aea 100644 --- a/src/include/optimizer/binding.h +++ b/src/include/optimizer/binding.h @@ -46,9 +46,7 @@ class BindingIterator { class GroupBindingIterator : public BindingIterator { public: - // TODO(ncx): pattern - GroupBindingIterator(Memo& memo, GroupID id, - std::shared_ptr pattern); + GroupBindingIterator(Memo& memo, GroupID id, std::shared_ptr pattern); bool HasNext() override; @@ -69,16 +67,13 @@ class GroupBindingIterator : public BindingIterator { class GroupExprBindingIterator : public BindingIterator { public: - // TODO(ncx): pattern - 
GroupExprBindingIterator(Memo& memo, GroupExpression *gexpr, - std::shared_ptr pattern); + GroupExprBindingIterator(Memo& memo, GroupExpression *gexpr, std::shared_ptr pattern); bool HasNext() override; std::shared_ptr Next() override; private: - // TODO(ncx): pattern GroupExpression* gexpr_; std::shared_ptr pattern_; diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index 6ec2c09400a..9a64c6af4e7 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -18,7 +18,6 @@ namespace peloton { namespace optimizer { -template class Memo; } @@ -36,9 +35,9 @@ class ChildPropertyDeriver : public OperatorVisitor { std::vector, std::vector>>> - GetProperties(GroupExpression *gexpr, + GetProperties(GroupExpression *gexpr, std::shared_ptr requirements, - Memo *memo); + Memo *memo); void Visit(const DummyScan *) override; void Visit(const PhysicalSeqScan *) override; @@ -78,8 +77,8 @@ class ChildPropertyDeriver : public OperatorVisitor { * @brief We need the memo and gexpr because some property may depend on * child's schema */ - Memo *memo_; - GroupExpression *gexpr_; + GroupExpression *gexpr_; + Memo *memo_; }; } // namespace optimizer diff --git a/src/include/optimizer/cost_model/abstract_cost_model.h b/src/include/optimizer/cost_model/abstract_cost_model.h index e01548739b1..0a57be183d7 100644 --- a/src/include/optimizer/cost_model/abstract_cost_model.h +++ b/src/include/optimizer/cost_model/abstract_cost_model.h @@ -18,7 +18,6 @@ namespace peloton { namespace optimizer { -template class Memo; // Default cost when cost model cannot compute correct cost. 
@@ -36,8 +35,8 @@ static constexpr double DEFAULT_OPERATOR_COST = 0.0025; class AbstractCostModel : public OperatorVisitor { public: - virtual double CalculateCost(GroupExpression *gexpr, - Memo *memo, + virtual double CalculateCost(GroupExpression *gexpr, + Memo *memo, concurrency::TransactionContext *txn) = 0; }; diff --git a/src/include/optimizer/group.h b/src/include/optimizer/group.h index 9129a4952a8..e2f24ca953a 100644 --- a/src/include/optimizer/group.h +++ b/src/include/optimizer/group.h @@ -32,7 +32,6 @@ class ColumnStats; //===--------------------------------------------------------------------===// // Group //===--------------------------------------------------------------------===// -template class Group : public Printable { public: Group(GroupID id, std::unordered_set table_alias); @@ -40,30 +39,30 @@ class Group : public Printable { // If the GroupExpression is generated by applying a // property enforcer, we add them to enforced_exprs_ // which will not be enumerated during OptimizeExpression - void AddExpression(std::shared_ptr> expr, + void AddExpression(std::shared_ptr expr, bool enforced); void RemoveLogicalExpression(size_t idx) { logical_expressions_.erase(logical_expressions_.begin() + idx); } - bool SetExpressionCost(GroupExpression *expr, double cost, + bool SetExpressionCost(GroupExpression *expr, double cost, std::shared_ptr &properties); - GroupExpression *GetBestExpression(std::shared_ptr &properties); + GroupExpression *GetBestExpression(std::shared_ptr &properties); inline const std::unordered_set &GetTableAliases() const { return table_aliases_; } // TODO: thread safety? - const std::vector>> GetLogicalExpressions() + const std::vector> GetLogicalExpressions() const { return logical_expressions_; } // TODO: thread safety? 
- const std::vector>> GetPhysicalExpressions() + const std::vector> GetPhysicalExpressions() const { return physical_expressions_; } @@ -100,14 +99,17 @@ class Group : public Printable { // This is called in rewrite phase to erase the only logical expression in the // group inline void EraseLogicalExpression() { - PELOTON_ASSERT(logical_expressions_.size() == 1); + // During query rewriting (pre-optimizer), the rewriter can execute in a scenario + // where a group can have multiple logical expressions (due to AND/OR/= equivalence). + // TODO(): refine these assertions to distinguish between optimizer/rewrite stages + PELOTON_ASSERT(logical_expressions_.size() >= 1); PELOTON_ASSERT(physical_expressions_.size() == 0); logical_expressions_.clear(); } // This should only be called in rewrite phase to retrieve the only logical // expr in the group - inline GroupExpression *GetLogicalExpression() { + inline GroupExpression *GetLogicalExpression() { PELOTON_ASSERT(logical_expressions_.size() == 1); PELOTON_ASSERT(physical_expressions_.size() == 0); return logical_expressions_[0].get(); @@ -119,15 +121,15 @@ class Group : public Printable { // TODO(boweic) Do not use string, store table alias id std::unordered_set table_aliases_; std::unordered_map, - std::tuple *>, PropSetPtrHash, + std::tuple, PropSetPtrHash, PropSetPtrEq> lowest_cost_expressions_; // Whether equivalent logical expressions have been explored for this group bool has_explored_; - std::vector>> logical_expressions_; - std::vector>> physical_expressions_; - std::vector>> enforced_exprs_; + std::vector> logical_expressions_; + std::vector> physical_expressions_; + std::vector> enforced_exprs_; // We'll add stats lazily // TODO(boweic): diff --git a/src/include/optimizer/group_expression.h b/src/include/optimizer/group_expression.h index fa79046b69a..a39c9471e59 100644 --- a/src/include/optimizer/group_expression.h +++ b/src/include/optimizer/group_expression.h @@ -32,7 +32,6 @@ using GroupID = int32_t; 
//===--------------------------------------------------------------------===// // Group Expression //===--------------------------------------------------------------------===// -template class GroupExpression { public: GroupExpression(std::shared_ptr node, std::vector child_groups); @@ -90,14 +89,3 @@ class GroupExpression { } // namespace optimizer } // namespace peloton - -namespace std { - -template <> -struct hash { - typedef peloton::optimizer::GroupExpression argument_type; - typedef std::size_t result_type; - result_type operator()(argument_type const &s) const { return s.Hash(); } -}; - -} // namespace std diff --git a/src/include/optimizer/memo.h b/src/include/optimizer/memo.h index 4bc77009de8..4ad1633ae4d 100644 --- a/src/include/optimizer/memo.h +++ b/src/include/optimizer/memo.h @@ -22,15 +22,12 @@ namespace peloton { namespace optimizer { -template struct GExprPtrHash { - std::size_t operator()(GroupExpression* const& s) const { return s->Hash(); } + std::size_t operator()(GroupExpression* const& s) const { return s->Hash(); } }; -template struct GExprPtrEq { - bool operator()(GroupExpression* const& t1, - GroupExpression* const& t2) const { + bool operator()(GroupExpression* const& t1, GroupExpression* const& t2) const { return *t1 == *t2; } }; @@ -38,7 +35,6 @@ struct GExprPtrEq { //===--------------------------------------------------------------------===// // Memo //===--------------------------------------------------------------------===// -template class Memo { public: Memo(); @@ -51,17 +47,13 @@ class Memo { * target_group: an optional target group to insert expression into * return: existing expression if found. 
Otherwise, return the new expr */ - GroupExpression* InsertExpression( - std::shared_ptr> gexpr, - bool enforced); + GroupExpression* InsertExpression(std::shared_ptr gexpr, bool enforced); - GroupExpression* InsertExpression( - std::shared_ptr> gexpr, - GroupID target_group, bool enforced); + GroupExpression* InsertExpression(std::shared_ptr gexpr, GroupID target_group, bool enforced); - std::vector>>& Groups(); + std::vector>& Groups(); - Group* GetGroupByID(GroupID id); + Group* GetGroupByID(GroupID id); const std::string GetInfo(int num_indent) const; const std::string GetInfo() const; @@ -73,34 +65,30 @@ class Memo { //===--------------------------------------------------------------------===// // For rewrite phase: remove and add expression directly for the set //===--------------------------------------------------------------------===// - void RemoveParExpressionForRewirte(GroupExpression* gexpr) { + void RemoveParExpressionForRewirte(GroupExpression* gexpr) { group_expressions_.erase(gexpr); } - void AddParExpressionForRewrite(GroupExpression* gexpr) { + void AddParExpressionForRewrite(GroupExpression* gexpr) { group_expressions_.insert(gexpr); } // When a rewrite rule is applied, we need to replace the original gexpr with // a new one, which reqires us to first remove the original gexpr from the // memo void EraseExpression(GroupID group_id) { - auto gexpr = groups_[group_id]->GetLogicalExpression(); - group_expressions_.erase(gexpr); + std::vector> gexprs = groups_[group_id]->GetLogicalExpressions(); + for (auto gexpr : gexprs) { + group_expressions_.erase(gexpr.get()); + } + groups_[group_id]->EraseLogicalExpression(); } private: - GroupID AddNewGroup(std::shared_ptr> gexpr); - - // Internal InsertExpression function - GroupExpression* InsertExpr( - std::shared_ptr> gexpr, - GroupID target_group, bool enforced); + GroupID AddNewGroup(std::shared_ptr gexpr); // The group owns the group expressions, not the memo - std::unordered_set*, - GExprPtrHash, - 
GExprPtrEq> group_expressions_; - std::vector>> groups_; + std::unordered_set group_expressions_; + std::vector> groups_; size_t rule_set_size_; }; diff --git a/src/include/optimizer/optimize_context.h b/src/include/optimizer/optimize_context.h index 15747a44b5a..b5568208d9e 100644 --- a/src/include/optimizer/optimize_context.h +++ b/src/include/optimizer/optimize_context.h @@ -22,20 +22,18 @@ namespace peloton { namespace optimizer { -template class OptimizerMetadata; -template class OptimizeContext { public: - OptimizeContext(OptimizerMetadata *metadata, + OptimizeContext(OptimizerMetadata *metadata, std::shared_ptr required_prop, double cost_upper_bound = std::numeric_limits::max()) : metadata(metadata), required_prop(required_prop), cost_upper_bound(cost_upper_bound) {} - OptimizerMetadata *metadata; + OptimizerMetadata *metadata; std::shared_ptr required_prop; double cost_upper_bound; }; diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h index 668049b5333..93f1ddf1e76 100644 --- a/src/include/optimizer/optimizer.h +++ b/src/include/optimizer/optimizer.h @@ -60,10 +60,7 @@ enum CostModels {DEFAULT, POSTGRES, TRIVIAL}; // Optimizer //===--------------------------------------------------------------------===// class Optimizer : public AbstractOptimizer { - template friend class BindingIterator; - - template friend class GroupBindingIterator; friend class ::peloton::test:: @@ -88,18 +85,18 @@ class Optimizer : public AbstractOptimizer { void Reset() override; - OptimizerMetadata &GetMetadata() { return metadata_; } + OptimizerMetadata &GetMetadata() { return metadata_; } /* For test purposes only */ - std::shared_ptr> TestInsertQueryTree( + std::shared_ptr TestInsertQueryTree( parser::SQLStatement *tree, concurrency::TransactionContext *txn) { return InsertQueryTree(tree, txn); } /* For test purposes only */ - void TestExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, - std::shared_ptr> root_context) { + void 
TestExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, + std::shared_ptr root_context) { return ExecuteTaskStack(task_stack, root_group_id, root_context); } @@ -124,7 +121,7 @@ class Optimizer : public AbstractOptimizer { * tree: a peloton query tree representing a select query * return: the root group expression for the inserted query */ - std::shared_ptr> InsertQueryTree( + std::shared_ptr InsertQueryTree( parser::SQLStatement *tree, concurrency::TransactionContext *txn); /* GetQueryTreeRequiredProperties - get the required physical properties for @@ -166,12 +163,12 @@ class Optimizer : public AbstractOptimizer { * root_context: the OptimizerContext to use that maintains required *properties */ - void ExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, - std::shared_ptr> root_context); + void ExecuteTaskStack(OptimizerTaskStack &task_stack, int root_group_id, + std::shared_ptr root_context); ////////////////////////////////////////////////////////////////////////////// /// Metadata - OptimizerMetadata metadata_; + OptimizerMetadata metadata_; std::unique_ptr cost_model_; }; diff --git a/src/include/optimizer/optimizer_metadata.h b/src/include/optimizer/optimizer_metadata.h index 6732c90199c..57dcb2ec7d8 100644 --- a/src/include/optimizer/optimizer_metadata.h +++ b/src/include/optimizer/optimizer_metadata.h @@ -39,7 +39,7 @@ class OptimizerMetadata { settings::SettingId::task_execution_timeout)), timer(Timer()) {} - Memo OperatorType, OperatorExpr> memo; + Memo memo; RuleSet rule_set; OptimizerTaskPool *task_pool; std::unique_ptr cost_model; diff --git a/src/include/optimizer/optimizer_task.h b/src/include/optimizer/optimizer_task.h index 173c64075c6..e02945e21b5 100644 --- a/src/include/optimizer/optimizer_task.h +++ b/src/include/optimizer/optimizer_task.h @@ -2,7 +2,7 @@ // // Peloton // -// rule.h +// optimizer_task.h // // Identification: src/include/optimizer/optimizer_task.h // @@ -14,6 +14,7 @@ #include #include +#include 
#include "expression/abstract_expression.h" #include "common/internal_types.h" @@ -24,28 +25,13 @@ class AbstractExpression; } namespace optimizer { -template class OptimizeContext; - -template class Memo; - -template class Rule; - -template struct RuleWithPromise; - -template class RuleSet; - -template class Group; - -template class GroupExpression; - -template class OptimizerMetadata; enum class OpType; @@ -72,10 +58,9 @@ enum class OptimizerTaskType { /** * @brief The base class for tasks in the optimizer */ -template class OptimizerTask { public: - OptimizerTask(std::shared_ptr> context, + OptimizerTask(std::shared_ptr context, OptimizerTaskType type) : type_(type), context_(context) {} @@ -91,24 +76,24 @@ class OptimizerTask { * @param valid_rules The valid rules to apply in the current rule set will be * append to valid_rules, with their promises */ - static void ConstructValidRules(GroupExpression *group_expr, - OptimizeContext *context, - std::vector>> &rules, - std::vector> &valid_rules); + static void ConstructValidRules(GroupExpression *group_expr, + OptimizeContext *context, + std::vector> &rules, + std::vector &valid_rules); virtual void execute() = 0; - void PushTask(OptimizerTask *task); + void PushTask(OptimizerTask *task); - inline Memo &GetMemo() const; + inline Memo &GetMemo() const; - inline RuleSet &GetRuleSet() const; + inline RuleSet &GetRuleSet() const; virtual ~OptimizerTask(){}; protected: OptimizerTaskType type_; - std::shared_ptr> context_; + std::shared_ptr context_; }; /** @@ -116,16 +101,15 @@ class OptimizerTask { * equivalent operator trees if not already explored 2. 
Cost all physical * operator trees given the current context */ -class OptimizeGroup : public OptimizerTask { +class OptimizeGroup : public OptimizerTask { public: - OptimizeGroup(Group *group, - std::shared_ptr> context) + OptimizeGroup(Group *group, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_GROUP), group_(group) {} virtual void execute() override; private: - Group *group_; + Group *group_; }; /** @@ -135,32 +119,30 @@ class OptimizeGroup : public OptimizerTask { * promises so that a physical transformation rule is applied before a logical * transformation rule */ -class OptimizeExpression : public OptimizerTask { +class OptimizeExpression : public OptimizerTask { public: - OptimizeExpression(GroupExpression *group_expr, - std::shared_ptr> context) + OptimizeExpression(GroupExpression *group_expr, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_EXPR), group_expr_(group_expr) {} virtual void execute() override; private: - GroupExpression *group_expr_; + GroupExpression *group_expr_; }; /** * @brief Generate all logical transformation rules by applying logical * transformation rules to logical operators in the group until saturated */ -class ExploreGroup : public OptimizerTask { +class ExploreGroup : public OptimizerTask { public: - ExploreGroup(Group *group, - std::shared_ptr> context) + ExploreGroup(Group *group, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::EXPLORE_GROUP), group_(group) {} virtual void execute() override; private: - Group *group_; + Group *group_; }; /** @@ -168,16 +150,15 @@ class ExploreGroup : public OptimizerTask { * pattern * in the same group is found, also apply logical transformation rule for it. 
*/ -class ExploreExpression : public OptimizerTask { +class ExploreExpression : public OptimizerTask { public: - ExploreExpression(GroupExpression *group_expr, - std::shared_ptr> context) + ExploreExpression(GroupExpression *group_expr, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::EXPLORE_EXPR), group_expr_(group_expr) {} virtual void execute() override; private: - GroupExpression *group_expr_; + GroupExpression *group_expr_; }; /** @@ -186,11 +167,10 @@ class ExploreExpression : public OptimizerTask { +class ApplyRule : public OptimizerTask { public: - ApplyRule(GroupExpression *group_expr, - Rule *rule, - std::shared_ptr> context, bool explore = false) + ApplyRule(GroupExpression *group_expr, Rule *rule, + std::shared_ptr context, bool explore = false) : OptimizerTask(context, OptimizerTaskType::APPLY_RULE), group_expr_(group_expr), rule_(rule), @@ -198,8 +178,8 @@ class ApplyRule : public OptimizerTask { virtual void execute() override; private: - GroupExpression *group_expr_; - Rule *rule_; + GroupExpression *group_expr_; + Rule *rule_; bool explore_only; }; @@ -210,10 +190,9 @@ class ApplyRule : public OptimizerTask { * current expression's cost is larger than the upper bound of the current * group */ -class OptimizeInputs : public OptimizerTask { +class OptimizeInputs : public OptimizerTask { public: - OptimizeInputs(GroupExpression *group_expr, - std::shared_ptr> context) + OptimizeInputs(GroupExpression *group_expr, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_INPUTS), group_expr_(group_expr) {} @@ -231,7 +210,7 @@ class OptimizeInputs : public OptimizerTask std::vector, std::vector>>> output_input_properties_; - GroupExpression *group_expr_; + GroupExpression *group_expr_; double cur_total_cost_; int cur_child_idx_ = -1; int prev_child_idx_ = -1; @@ -243,11 +222,9 @@ class OptimizeInputs : public OptimizerTask * child group have the stats, if not, recursively derive the stats. 
This would * lazily collect the stats for the column needed */ -class DeriveStats : public OptimizerTask { +class DeriveStats : public OptimizerTask { public: - DeriveStats(GroupExpression *gexpr, - ExprSet required_cols, - std::shared_ptr> context) + DeriveStats(GroupExpression *gexpr, ExprSet required_cols, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::DERIVE_STATS), gexpr_(gexpr), required_cols_(required_cols) {} @@ -260,30 +237,56 @@ class DeriveStats : public OptimizerTask { virtual void execute() override; private: - GroupExpression *gexpr_; + GroupExpression *gexpr_; ExprSet required_cols_; }; + +/** + * @brief Higher abstraction above TopDownRewrite and BottomUpRewrite that + * implements functionality similar and relied upon by both TopDownRewrite + * and BottomUpRewrite. + */ +class RewriteTask : public OptimizerTask { + public: + RewriteTask(OptimizerTaskType type, GroupID group_id, + std::shared_ptr context, + RewriteRuleSetName rule_set_name) + : OptimizerTask(context, type), + group_id_(group_id), + rule_set_name_(rule_set_name) {} + + virtual void execute() override { + LOG_ERROR("RewriteTask::execute invoked directly and not on derived"); + PELOTON_ASSERT(0); + }; + + protected: + std::set GetUniqueChildGroupIDs(); + bool OptimizeCurrentGroup(bool replace_on_match); + + GroupID group_id_; + RewriteRuleSetName rule_set_name_; +}; + /** * @brief Apply top-down rewrite pass, take in a rule set which must fulfill * that the lower level rewrite in the operator tree will not enable upper * level rewrite. An example is predicate push-down. We only push the predicates * from the upper level to the lower level. 
*/ -template -class TopDownRewrite : public OptimizerTask { +class TopDownRewrite : public RewriteTask { public: - TopDownRewrite(GroupID group_id, - std::shared_ptr> context, + TopDownRewrite(GroupID group_id, std::shared_ptr context, RewriteRuleSetName rule_set_name) - : OptimizerTask(context, OptimizerTaskType::TOP_DOWN_REWRITE), - group_id_(group_id), - rule_set_name_(rule_set_name) {} + : RewriteTask(OptimizerTaskType::TOP_DOWN_REWRITE, group_id, context, rule_set_name), + replace_on_transform_(true) {} + + void SetReplaceOnTransform(bool replace) { replace_on_transform_ = replace; } virtual void execute() override; private: - GroupID group_id_; - RewriteRuleSetName rule_set_name_; + bool replace_on_transform_; }; /** @@ -291,21 +294,16 @@ class TopDownRewrite : public OptimizerTask { * that the upper level rewrite in the operator tree will not enable lower * level rewrite. */ -template -class BottomUpRewrite : public OptimizerTask { +class BottomUpRewrite : public RewriteTask { public: - BottomUpRewrite(GroupID group_id, - std::shared_ptr> context, + BottomUpRewrite(GroupID group_id, std::shared_ptr context, RewriteRuleSetName rule_set_name, bool has_optimized_child) - : OptimizerTask(context, OptimizerTaskType::BOTTOM_UP_REWRITE), - group_id_(group_id), - rule_set_name_(rule_set_name), + : RewriteTask(OptimizerTaskType::BOTTOM_UP_REWRITE, group_id, context, rule_set_name), has_optimized_child_(has_optimized_child) {} + virtual void execute() override; private: - GroupID group_id_; - RewriteRuleSetName rule_set_name_; bool has_optimized_child_; }; } // namespace optimizer diff --git a/src/include/optimizer/optimizer_task_pool.h b/src/include/optimizer/optimizer_task_pool.h index 2ce755e8de0..4165b865ac5 100644 --- a/src/include/optimizer/optimizer_task_pool.h +++ b/src/include/optimizer/optimizer_task_pool.h @@ -25,34 +25,32 @@ namespace optimizer { * structure for multi-threaded optimization */ -template class OptimizerTaskPool { public: - virtual 
std::unique_ptr> Pop() = 0; - virtual void Push(OptimizerTask *task) = 0; + virtual std::unique_ptr Pop() = 0; + virtual void Push(OptimizerTask *task) = 0; virtual bool Empty() = 0; }; /** * @brief Stack implementation of the task pool */ -template -class OptimizerTaskStack : public OptimizerTaskPool { +class OptimizerTaskStack : public OptimizerTaskPool { public: - virtual std::unique_ptr> Pop() { + virtual std::unique_ptr Pop() { auto task = std::move(task_stack_.top()); task_stack_.pop(); return task; } - virtual void Push(OptimizerTask *task) { - task_stack_.push(std::unique_ptr>(task)); + virtual void Push(OptimizerTask *task) { + task_stack_.push(std::unique_ptr(task)); } virtual bool Empty() { return task_stack_.empty(); } private: - std::stack>> task_stack_; + std::stack> task_stack_; }; } // namespace optimizer diff --git a/src/include/optimizer/pattern.h b/src/include/optimizer/pattern.h index 176fb382b9a..3db6eebeb6c 100644 --- a/src/include/optimizer/pattern.h +++ b/src/include/optimizer/pattern.h @@ -20,13 +20,11 @@ namespace peloton { namespace optimizer { -/** - * template parameter should *really* only be OpType or ExpressionType - */ -template class Pattern { public: - Pattern(OperatorType op); + Pattern(OpType op); + + Pattern(ExpressionType exp_type); void AddChild(std::shared_ptr child); @@ -34,10 +32,13 @@ class Pattern { inline size_t GetChildPatternsSize() const { return children.size(); } - OperatorType Type() const; + OpType GetOpType() const; + + ExpressionType GetExpType() const; private: - OperatorType _type; + OpType _op_type = OpType::Undefined; + ExpressionType _exp_type = ExpressionType::INVALID; std::vector> children; }; diff --git a/src/include/optimizer/property_enforcer.h b/src/include/optimizer/property_enforcer.h index c826edbe54d..e82b802d84c 100644 --- a/src/include/optimizer/property_enforcer.h +++ b/src/include/optimizer/property_enforcer.h @@ -30,8 +30,8 @@ class PropertyEnforcer : public PropertyVisitor { public: - 
std::shared_ptr> EnforceProperty( - GroupExpression* gexpr, Property* property); + std::shared_ptr EnforceProperty( + GroupExpression* gexpr, Property* property); virtual void Visit(const PropertyColumns *) override; virtual void Visit(const PropertySort *) override; @@ -39,8 +39,8 @@ class PropertyEnforcer : public PropertyVisitor { virtual void Visit(const PropertyLimit *) override; private: - GroupExpression* input_gexpr_; - std::shared_ptr> output_gexpr_; + GroupExpression* input_gexpr_; + std::shared_ptr output_gexpr_; }; } // namespace optimizer diff --git a/src/include/optimizer/rewriter.h b/src/include/optimizer/rewriter.h index 796b10f7779..161692a1a70 100644 --- a/src/include/optimizer/rewriter.h +++ b/src/include/optimizer/rewriter.h @@ -25,26 +25,23 @@ namespace optimizer { class Rewriter { public: - Rewriter(const Rewriter &) = delete; - Rewriter &operator=(const Rewriter &) = delete; - Rewriter(Rewriter &&) = delete; - Rewriter &operator=(Rewriter &&) = delete; - Rewriter(); - - expression::AbstractExpression* RewriteExpression(const expression::AbstractExpression *expr); void Reset(); - OptimizerMetadata &GetMetadata() { return metadata_; } + DISALLOW_COPY_AND_MOVE(Rewriter); - std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression *expr); + OptimizerMetadata &GetMetadata() { return metadata_; } + + expression::AbstractExpression* RewriteExpression(const expression::AbstractExpression *expr); private: expression::AbstractExpression* RebuildExpression(int root_group); - void ExecuteTaskStack(OptimizerTaskStack &task_stack); void RewriteLoop(int root_group_id); - std::shared_ptr> ConvertTree(const expression::AbstractExpression *expr); - OptimizerMetadata metadata_; + + std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression *expr); + std::shared_ptr RecordTreeGroups(const expression::AbstractExpression *expr); + + OptimizerMetadata metadata_; }; } // namespace optimizer diff --git a/src/include/optimizer/rule.h 
b/src/include/optimizer/rule.h index fba7e985b5a..cbd21a79738 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -113,7 +113,8 @@ struct RuleWithPromise { enum class RewriteRuleSetName : uint32_t { PREDICATE_PUSH_DOWN = 0, UNNEST_SUBQUERY, - COMPARATOR_ELIMINATION + EQUIVALENT_TRANSFORM, + GENERIC_RULES }; /** diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h index fe0f2b829bf..8df83556626 100644 --- a/src/include/optimizer/rule_rewrite.h +++ b/src/include/optimizer/rule_rewrite.h @@ -27,19 +27,49 @@ enum class RulePriority : int { LOW = 1 }; -class ComparatorElimination: public Rule { +class ComparatorElimination: public Rule { public: - ComparatorElimination(); + ComparatorElimination(RuleType rule, ExpressionType root); - int Promise(GroupExpression *group_expr, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +class EquivalentTransform: public Rule { + public: + EquivalentTransform(RuleType rule, ExpressionType root); - bool Check(std::shared_ptr plan, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +class TVEqualityWithTwoCVTransform: public Rule { + public: + TVEqualityWithTwoCVTransform(); - void Transform(std::shared_ptr input, - std::vector> &transformed, - OptimizeContext *context) const override; + int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const 
override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; }; + +class TransitiveClosureConstantTransform: public Rule { + public: + TransitiveClosureConstantTransform(); + + int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + } // namespace optimizer } // namespace peloton diff --git a/src/include/optimizer/stats/child_stats_deriver.h b/src/include/optimizer/stats/child_stats_deriver.h index cfca18e30d9..f4f3c05be20 100644 --- a/src/include/optimizer/stats/child_stats_deriver.h +++ b/src/include/optimizer/stats/child_stats_deriver.h @@ -21,7 +21,6 @@ class AbstractExpression; } namespace optimizer { -template class Memo; class OperatorExpression; @@ -31,9 +30,9 @@ class OperatorExpression; class ChildStatsDeriver : public OperatorVisitor { public: std::vector DeriveInputStats( - GroupExpression *gexpr, + GroupExpression *gexpr, ExprSet required_cols, - Memo *memo); + Memo *memo); void Visit(const LogicalQueryDerivedGet *) override; void Visit(const LogicalInnerJoin *) override; @@ -47,8 +46,8 @@ class ChildStatsDeriver : public OperatorVisitor { void PassDownRequiredCols(); void PassDownColumn(expression::AbstractExpression* col); ExprSet required_cols_; - GroupExpression *gexpr_; - Memo *memo_; + GroupExpression *gexpr_; + Memo *memo_; std::vector output_; }; diff --git a/src/include/optimizer/stats/stats_calculator.h b/src/include/optimizer/stats/stats_calculator.h index 6fed68370f9..79d1988de95 100644 --- a/src/include/optimizer/stats/stats_calculator.h +++ b/src/include/optimizer/stats/stats_calculator.h @@ -17,7 +17,6 @@ namespace peloton { namespace optimizer { -template class Memo; class TableStats; class OperatorExpression; @@ -28,9 +27,9 @@ class OperatorExpression; */ 
class StatsCalculator : public OperatorVisitor { public: - void CalculateStats(GroupExpression *gexpr, + void CalculateStats(GroupExpression *gexpr, ExprSet required_cols, - Memo *memo, + Memo *memo, concurrency::TransactionContext* txn); void Visit(const LogicalGet *) override; @@ -76,9 +75,9 @@ class StatsCalculator : public OperatorVisitor { const std::shared_ptr predicate_table_stats, const expression::AbstractExpression *expr); - GroupExpression *gexpr_; + GroupExpression *gexpr_; ExprSet required_cols_; - Memo *memo_; + Memo *memo_; concurrency::TransactionContext* txn_; }; diff --git a/src/include/traffic_cop/traffic_cop.h b/src/include/traffic_cop/traffic_cop.h index e324b87fe82..8870591a4df 100644 --- a/src/include/traffic_cop/traffic_cop.h +++ b/src/include/traffic_cop/traffic_cop.h @@ -25,6 +25,7 @@ #include "common/statement.h" #include "executor/plan_executor.h" #include "optimizer/abstract_optimizer.h" +#include "optimizer/rewriter.h" #include "parser/sql_statement.h" #include "type/type.h" @@ -196,6 +197,8 @@ class TrafficCop { // still a HACK void GetTableColumns(parser::TableRef *from_table, std::vector &target_tables); + + optimizer::Rewriter rewriter_; }; } // namespace tcop diff --git a/src/optimizer/absexpr_expression.cpp b/src/optimizer/absexpr_expression.cpp new file mode 100644 index 00000000000..c0e8d5ca8da --- /dev/null +++ b/src/optimizer/absexpr_expression.cpp @@ -0,0 +1,146 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// absexpr_expression.cpp +// +// Identification: src/optimizer/absexpr_expression.cpp +// +//===----------------------------------------------------------------------===// + +#include "optimizer/absexpr_expression.h" +#include "expression/operator_expression.h" +#include "expression/aggregate_expression.h" +#include "expression/star_expression.h" + +#include +#include + +namespace peloton { +namespace optimizer { + +expression::AbstractExpression 
*AbsExpr_Container::CopyWithChildren(std::vector children) { + // Pre-compute left and right + expression::AbstractExpression *left = nullptr; + expression::AbstractExpression *right = nullptr; + if (children.size() >= 2) { + left = children[0]; + right = children[1]; + } else if (children.size() == 1) { + left = children[0]; + } + + auto type = GetExpType(); + switch (type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_LIKE: + case ExpressionType::COMPARE_NOTLIKE: + case ExpressionType::COMPARE_IN: + case ExpressionType::COMPARE_DISTINCT_FROM: { + // Create new expression with 2 new children of the same type + return new expression::ComparisonExpression(type, left, right); + } + + case ExpressionType::CONJUNCTION_AND: + case ExpressionType::CONJUNCTION_OR: { + // Create new expression with the new children + return new expression::ConjunctionExpression(type, left, right); + } + + case ExpressionType::OPERATOR_PLUS: + case ExpressionType::OPERATOR_MINUS: + case ExpressionType::OPERATOR_MULTIPLY: + case ExpressionType::OPERATOR_DIVIDE: + case ExpressionType::OPERATOR_CONCAT: + case ExpressionType::OPERATOR_MOD: + case ExpressionType::OPERATOR_NOT: + case ExpressionType::OPERATOR_IS_NULL: + case ExpressionType::OPERATOR_IS_NOT_NULL: + case ExpressionType::OPERATOR_EXISTS: { + // Create new expression, preserving return_value_type_ + type::TypeId ret = expr->GetValueType(); + return new expression::OperatorExpression(type, ret, left, right); + } + + case ExpressionType::OPERATOR_UNARY_MINUS: { + PELOTON_ASSERT(children.size() == 1); + return new expression::OperatorUnaryMinusExpression(left); + } + + case ExpressionType::STAR: + case ExpressionType::VALUE_CONSTANT: + case ExpressionType::VALUE_PARAMETER: + case 
ExpressionType::VALUE_TUPLE: { + PELOTON_ASSERT(children.size() == 0); + return expr->Copy(); + } + + case ExpressionType::AGGREGATE_COUNT: + case ExpressionType::AGGREGATE_COUNT_STAR: + case ExpressionType::AGGREGATE_SUM: + case ExpressionType::AGGREGATE_MIN: + case ExpressionType::AGGREGATE_MAX: + case ExpressionType::AGGREGATE_AVG: { + // We should not be changing # of children of AggregateExpression + PELOTON_ASSERT(expr->GetChildrenSize() == children.size()); + + // Unfortunately, the aggregate_expression (also applies to function) + // may already have extra state information created due to the binder. + // Under Peloton's design, we decide to just copy() the node and then + // install the child. + auto expr_copy = expr->Copy(); + if (children.size() == 1) { + // If we updated the child, install the child + expr_copy->SetChild(0, children[0]); + } + + return expr_copy; + } + + case ExpressionType::FUNCTION: { + // We really should not be modifying # children of Function + PELOTON_ASSERT(children.size() == expr->GetChildrenSize()); + auto copy = expr->Copy(); + + size_t num_child = children.size(); + for (size_t i = 0; i < num_child; i++) { + copy->SetChild(i, children[i]); + } + return copy; + } + + // Rewriting for these 2 uses special matching patterns. + // As such, when building as an output, we just copy directly. 
+ case ExpressionType::ROW_SUBQUERY: + case ExpressionType::OPERATOR_CASE_EXPR: { + PELOTON_ASSERT(children.size() == 0); + return expr->Copy(); + } + + // These ExpressionTypes are never instantiated as a type + case ExpressionType::PLACEHOLDER: + case ExpressionType::COLUMN_REF: + case ExpressionType::FUNCTION_REF: + case ExpressionType::TABLE_REF: + case ExpressionType::SELECT_SUBQUERY: + case ExpressionType::VALUE_TUPLE_ADDRESS: + case ExpressionType::VALUE_NULL: + case ExpressionType::VALUE_VECTOR: + case ExpressionType::VALUE_SCALAR: + case ExpressionType::HASH_RANGE: + case ExpressionType::OPERATOR_CAST: + default: { + int type = static_cast(GetExpType()); + LOG_ERROR("Unimplemented Rebuild() for %d found", type); + return expr->Copy(); + } + } +} + +} // namespace optimizer +} // namespace peloton diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index e482d909b6b..807c4c42f94 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -16,6 +16,7 @@ #include "optimizer/operator_visitor.h" #include "optimizer/optimizer.h" #include "optimizer/absexpr_expression.h" +#include "expression/group_marker_expression.h" namespace peloton { namespace optimizer { @@ -34,11 +35,12 @@ GroupBindingIterator::GroupBindingIterator(Memo &memo, GroupID id, LOG_TRACE("Attempting to bind on group %d", id); } -bool GroupBindingIterator::HasNextBinding() { +bool GroupBindingIterator::HasNext() { LOG_TRACE("HasNextBinding"); - // TODO(ncx): pattern - if (pattern_->Type() == OpType::Leaf) { + // TODO(): Can we do this generic pattern any better? 
+ if ((pattern_->GetOpType() == OpType::Leaf && pattern_->GetExpType() == ExpressionType::INVALID) || + (pattern_->GetOpType() == OpType::Undefined && pattern_->GetExpType() == ExpressionType::GROUP_MARKER)) { return current_item_index_ == 0; } @@ -71,12 +73,18 @@ bool GroupBindingIterator::HasNextBinding() { } std::shared_ptr GroupBindingIterator::Next() { - // TODO(ncx): pattern - if (pattern_->Type() == OpType::Leaf) { + if (pattern_->GetOpType() == OpType::Leaf && pattern_->GetExpType() == ExpressionType::INVALID) { current_item_index_ = num_group_items_; return std::make_shared(LeafOperator::make(group_id_)); } + if (pattern_->GetOpType() == OpType::Undefined && pattern_->GetExpType() == ExpressionType::GROUP_MARKER) { + current_item_index_ = num_group_items_; + + auto expr = std::make_shared(group_id_); + return std::make_shared(std::make_shared(expr)); + } + return current_iterator_->Next(); } @@ -89,13 +97,20 @@ GroupExprBindingIterator::GroupExprBindingIterator( gexpr_(gexpr), pattern_(pattern), first_(true), - has_next_(false), - // TODO(ncx): needs workaround when Node is not an Operator - current_binding_(std::make_shared(gexpr->Node())) { - if (gexpr->Node()->GetOpType() != pattern->Type()) { + has_next_(false) { + + if (gexpr->Node()->GetOpType() != pattern->GetOpType() || + gexpr->Node()->GetExpType() != pattern->GetExpType()) { return; } + // Create right type of AbstractNodeExpression depending on type + if (gexpr->Node()->GetOpType() != OpType::Undefined) { + current_binding_ = std::make_shared(gexpr->Node()); + } else { + current_binding_ = std::make_shared(gexpr->Node()); + } + const std::vector &child_groups = gexpr->GetChildGroupIDs(); const std::vector> &child_patterns = pattern->Children(); @@ -180,12 +195,6 @@ std::shared_ptr GroupExprBindingIterator::Next() { return current_binding_; } -// Explicitly instantiate -template class GroupBindingIterator; -template class GroupExprBindingIterator; - -template class GroupBindingIterator; -template 
class GroupExprBindingIterator; } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index fff1d7a531b..c61cdc81d25 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -15,6 +15,7 @@ #include "optimizer/operators.h" #include "optimizer/stats/stats_calculator.h" #include "optimizer/absexpr_expression.h" +#include "expression/group_marker_expression.h" namespace peloton { namespace optimizer { @@ -31,7 +32,8 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, GroupID target_group, bool enforced) { // If leaf, then just return - if (gexpr->Node()->GetOpType() == OpType::Leaf) { + if (gexpr->Node()->GetOpType() == OpType::Leaf && + gexpr->Node()->GetExpType() == ExpressionType::INVALID) { const LeafOperator *leaf = gexpr->Node()->As(); PELOTON_ASSERT(target_group == UNDEFINED_GROUP || target_group == leaf->origin_group); @@ -39,6 +41,19 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, return nullptr; } + if (gexpr->Node()->GetOpType() == OpType::Undefined && + gexpr->Node()->GetExpType() == ExpressionType::GROUP_MARKER) { + + auto abs_node = std::dynamic_pointer_cast(gexpr->Node()); + PELOTON_ASSERT(abs_node != nullptr); + + auto gm_expr = std::dynamic_pointer_cast(abs_node->GetExpr()); + PELOTON_ASSERT(gm_expr != nullptr); + PELOTON_ASSERT(target_group == UNDEFINED_GROUP || target_group == gm_expr->GetGroupID()); + gexpr->SetGroupID(gm_expr->GetGroupID()); + return nullptr; + } + auto it = group_expressions_.find(gexpr.get()); if (it != group_expressions_.end()) { gexpr->SetGroupID((*it)->GetGroupID()); diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index 286788294ae..24edd6d6876 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -32,12 +32,9 @@ void OptimizerTask::ConstructValidRules( std::vector &valid_rules) { for (auto &rule : rules) { // Check if we can apply the rule - // TODO(ncx): replace 
after pattern fix - // bool root_pattern_mismatch = - // group_expr->Node()->GetOpType() != rule->GetMatchPattern()->OpType() - // || group_expr->Node()->GetExpType() != rule->GetMatchPattern()->ExpType(); - bool root_pattern_mismatch = - group_expr->Node()->GetOpType() != rule->GetMatchPattern()->Type(); + bool root_pattern_mismatch = group_expr->Node()->GetOpType() != rule->GetMatchPattern()->GetOpType() || + group_expr->Node()->GetExpType() != rule->GetMatchPattern()->GetExpType(); + bool already_explored = group_expr->HasRuleExplored(rule.get()); bool child_pattern_mismatch = group_expr->GetChildrenGroupsSize() != @@ -406,122 +403,139 @@ void OptimizeInputs::execute() { } } -void TopDownRewrite::execute() { - std::vector valid_rules; - +// =================================================================== +// +// RewriteTask related implementations +// +// =================================================================== +std::set RewriteTask::GetUniqueChildGroupIDs() { + // Get current group and logical expressions auto cur_group = this->GetMemo().GetGroupByID(group_id_); - auto cur_group_expr = cur_group->GetLogicalExpression(); - - // Construct valid transformation rules from rule set - this->ConstructValidRules(cur_group_expr, this->context_.get(), - this->GetRuleSet().GetRewriteRulesByName(rule_set_name_), - valid_rules); - - // Sort so that we apply rewrite rules with higher promise first - std::sort(valid_rules.begin(), valid_rules.end(), - std::greater()); - - for (auto &r : valid_rules) { - GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, - r.rule->GetMatchPattern()); - if (iterator.HasNext()) { - auto before = iterator.Next(); - PELOTON_ASSERT(!iterator.HasNext()); - - // (TODO): pending terrier issue #332 - // Check whether rule actually can be applied - // as opposed to a structural level test - if (!r.rule->Check(before, this->context_.get())) { - continue; - } - - std::vector> after; - r.rule->Transform(before, after, 
this->context_.get()); - - // Rewrite rule should provide at most 1 expression - PELOTON_ASSERT(after.size() <= 1); - // If a rule is applied, we replace the old expression and optimize this - // group again, this will ensure that we apply rule for this level until - // saturated - if (!after.empty()) { - auto &new_expr = after[0]; - this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); - this->PushTask(new TopDownRewrite(group_id_, this->context_, rule_set_name_)); - return; - } + auto cur_group_exprs = cur_group->GetLogicalExpressions(); + PELOTON_ASSERT(cur_group_exprs.size() >= 1); + + // Generate unique group ID numbers so we don't repeat work + std::set child_groups; + for (auto cur_group_expr : cur_group_exprs) { + for (size_t child = 0; child < cur_group_expr->GetChildrenGroupsSize(); child++) { + child_groups.insert(cur_group_expr->GetChildGroupId(child)); } - cur_group_expr->SetRuleExplored(r.rule); } - for (size_t child_group_idx = 0; - child_group_idx < cur_group_expr->GetChildrenGroupsSize(); - child_group_idx++) { - // Need to rewrite all sub trees first - this->PushTask( - new TopDownRewrite(cur_group_expr->GetChildGroupId(child_group_idx), - this->context_, rule_set_name_)); - } + return child_groups; } -void BottomUpRewrite::execute() { +bool RewriteTask::OptimizeCurrentGroup(bool replace_on_match) { std::vector valid_rules; + // Get current group and logical expressions auto cur_group = this->GetMemo().GetGroupByID(group_id_); - auto cur_group_expr = cur_group->GetLogicalExpression(); + auto cur_group_exprs = cur_group->GetLogicalExpressions(); + PELOTON_ASSERT(cur_group_exprs.size() >= 1); + + // Try to optimize all the logical group expressions. + // If one gets optimized, then the group is collapsed. 
+ for (auto cur_group_expr_ptr : cur_group_exprs) { + auto cur_group_expr = cur_group_expr_ptr.get(); + + // Construct valid transformation rules from rule set + this->ConstructValidRules(cur_group_expr, this->context_.get(), + this->GetRuleSet().GetRewriteRulesByName(rule_set_name_), + valid_rules); + + // Sort so that we apply rewrite rules with higher promise first + std::sort(valid_rules.begin(), valid_rules.end(), + std::greater()); + + // Try applying each rule + for (auto &r : valid_rules) { + GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, r.rule->GetMatchPattern()); + // Keep trying to apply until we exhaust all the bindings. + // This could possibly be sub-optimal since the first binding that results + // in a transformation by a rule will be applied and become the group's + // "new" rewritten expression. + while (iterator.HasNext()) { + // Binding succeeded to a given expression structure + auto before = iterator.Next(); + + // Attempt to apply the transformation + std::vector> after; + r.rule->Transform(before, after, this->context_.get()); + + // Rewrite rule should provide at most 1 expression + PELOTON_ASSERT(after.size() <= 1); + if (!after.empty()) { + // The transformation produced another expression + auto &new_expr = after[0]; + if (replace_on_match) { + // Replace entire group. We do not need to generate logically equivalent + // because rewriting expressions will not generate new AND or OR clauses. 
+ this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); + + // Return true to indicate optimize succeeded and the caller should try again + return true; + } else { + // Insert as a new logical equivalent expression + std::shared_ptr new_gexpr; + GroupID group = cur_group_expr->GetGroupID(); - if (!has_optimized_child_) { - this->PushTask(new BottomUpRewrite(group_id_, this->context_, rule_set_name_, true)); - for (size_t child_group_idx = 0; - child_group_idx < cur_group_expr->GetChildrenGroupsSize(); - child_group_idx++) { - // Need to rewrite all sub trees first - this->PushTask( - new BottomUpRewrite(cur_group_expr->GetChildGroupId(child_group_idx), - this->context_, rule_set_name_, false)); + // Try again only if we succeeded in recording a new expression + return this->context_->metadata->RecordTransformedExpression(new_expr, new_gexpr, group); + } + } + } + + cur_group_expr->SetRuleExplored(r.rule); } - return; } - // Construct valid transformation rules from rule set - this->ConstructValidRules(cur_group_expr, this->context_.get(), - this->GetRuleSet().GetRewriteRulesByName(rule_set_name_), - valid_rules); - // Sort so that we apply rewrite rules with higher promise first - std::sort(valid_rules.begin(), valid_rules.end(), - std::greater()); + return false; +} - for (auto &r : valid_rules) { - GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, - r.rule->GetMatchPattern()); - if (iterator.HasNext()) { - auto before = iterator.Next(); - PELOTON_ASSERT(!iterator.HasNext()); - - // (TODO): pending terrier issue #332 - // Check whether rule actually can be applied - // as opposed to a structural level test - if (!r.rule->Check(before, this->context_.get())) { - continue; - } +void TopDownRewrite::execute() { + bool did_optimize = this->OptimizeCurrentGroup(replace_on_transform_); + + // Optimize succeeded and by the design, there will ever be only 1 + // that is logically equivalent, so we do not need to perform + // any extra 
passes. Equivalence generating rules will not be repeatedly + // applied to expression trees. + // + // This is definitely sub-optimal and is a missed opportunity for rewrite. + // However, this requires AbstractExpression to support strict equality + // in its post-binding state. + if (did_optimize && replace_on_transform_) { + auto top = new TopDownRewrite(this->group_id_, this->context_, this->rule_set_name_); + top->SetReplaceOnTransform(replace_on_transform_); + this->PushTask(top); + return; + } - std::vector> after; - r.rule->Transform(before, after, this->context_.get()); + // This group has been optimized, so move on to the children + std::set child_groups = this->GetUniqueChildGroupIDs(); + for (auto g_id : child_groups) { + auto top = new TopDownRewrite(g_id, this->context_, this->rule_set_name_); + top->SetReplaceOnTransform(replace_on_transform_); + this->PushTask(top); + } +} - // Rewrite rule should provide at most 1 expression - PELOTON_ASSERT(after.size() <= 1); - // If a rule is applied, we replace the old expression and optimize this - // group again, this will ensure that we apply rule for this level until - // saturated, also childs are already been rewritten - if (!after.empty()) { - auto &new_expr = after[0]; - this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); - this->PushTask( - new BottomUpRewrite(group_id_, this->context_, rule_set_name_, false)); +void BottomUpRewrite::execute() { + if (!has_optimized_child_) { + this->PushTask(new BottomUpRewrite(this->group_id_, this->context_, this->rule_set_name_, true)); - return; - } + // Get all unique GroupIDs to minimize repeated work + // Need to rewrite all sub trees first + std::set child_groups = this->GetUniqueChildGroupIDs(); + for (auto g_id : child_groups) { + this->PushTask(new BottomUpRewrite(g_id, this->context_, this->rule_set_name_, false)); } - cur_group_expr->SetRuleExplored(r.rule); + + return; + } + + // Keep rewriting until we finish + if 
(this->OptimizeCurrentGroup(true)) { + this->PushTask(new BottomUpRewrite(this->group_id_, this->context_, this->rule_set_name_, false)); } } diff --git a/src/optimizer/pattern.cpp b/src/optimizer/pattern.cpp index 23b976888cf..81fd8b7d321 100644 --- a/src/optimizer/pattern.cpp +++ b/src/optimizer/pattern.cpp @@ -15,25 +15,19 @@ namespace peloton { namespace optimizer { -template -Pattern::Pattern(OperatorType op) : _type(op) {} +Pattern::Pattern(OpType op) : _op_type(op) {} +Pattern::Pattern(ExpressionType exp) : _exp_type(exp) {} -template -void Pattern::AddChild(std::shared_ptr> child) { +void Pattern::AddChild(std::shared_ptr child) { children.push_back(child); } -template -const std::vector>> &Pattern::Children() const { +const std::vector> &Pattern::Children() const { return children; } -template -OperatorType Pattern::Type() const { return _type; } - -// Explicitly instantiate -template class Pattern; -template class Pattern; +OpType Pattern::GetOpType() const { return _op_type; } +ExpressionType Pattern::GetExpType() const { return _exp_type; } } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/rewriter.cpp b/src/optimizer/rewriter.cpp index d23d998e51d..4a440af56c6 100644 --- a/src/optimizer/rewriter.cpp +++ b/src/optimizer/rewriter.cpp @@ -38,43 +38,42 @@ using std::make_shared; namespace peloton { namespace optimizer { -using OptimizerMetadataTemplate = OptimizerMetadata; - -using OptimizeContextTemplate = OptimizeContext; - -using OptimizerTaskStackTemplate = OptimizerTaskStack; - -using TopDownRewriteTemplate = TopDownRewrite; - -using BottomUpRewriteTemplate = BottomUpRewrite; - -using GroupExpressionTemplate = GroupExpression; - -using GroupTemplate = Group; - Rewriter::Rewriter() : metadata_(nullptr) { - metadata_ = OptimizerMetadataTemplate(nullptr); + metadata_ = OptimizerMetadata(nullptr); +} + +void Rewriter::Reset() { + metadata_ = OptimizerMetadata(nullptr); } void Rewriter::RewriteLoop(int root_group_id) { - 
std::shared_ptr root_context = - std::make_shared(&metadata_, nullptr); - auto task_stack = - std::unique_ptr(new OptimizerTaskStackTemplate()); + std::shared_ptr root_context = std::make_shared(&metadata_, nullptr); + auto task_stack = std::unique_ptr(new OptimizerTaskStack()); metadata_.SetTaskPool(task_stack.get()); - // Perform rewrite first - task_stack->Push(new BottomUpRewriteTemplate(root_group_id, root_context, RewriteRuleSetName::COMPARATOR_ELIMINATION, false)); + // Rewrite using all rules (which will be applied based on priority) + task_stack->Push(new BottomUpRewrite(root_group_id, root_context, RewriteRuleSetName::GENERIC_RULES, false)); + + // Generate equivalences first + auto equiv_task = new TopDownRewrite(root_group_id, root_context, RewriteRuleSetName::EQUIVALENT_TRANSFORM); + equiv_task->SetReplaceOnTransform(false); // generate equivalent + task_stack->Push(equiv_task); - ExecuteTaskStack(*task_stack); + // Iterate through the task stack + while (!task_stack->Empty()) { + auto task = task_stack->Pop(); + task->execute(); + } } expression::AbstractExpression* Rewriter::RebuildExpression(int root) { auto cur_group = metadata_.memo.GetGroupByID(root); auto exprs = cur_group->GetLogicalExpressions(); - // (TODO): what should we do if exprs.size() > 1? - PELOTON_ASSERT(exprs.size() > 0); + // If we optimized a group successfully, then it would have been + // collapsed to only a single group. If we did not optimize a group, + // then they are all equivalent, so pick any. 
+ PELOTON_ASSERT(exprs.size() >= 1); auto expr = exprs[0]; std::vector child_groups = expr->GetChildGroupIDs(); @@ -87,15 +86,49 @@ expression::AbstractExpression* Rewriter::RebuildExpression(int root) { child_exprs.push_back(child); } - AbsExpr_Container c = expr->Op(); - return c.Rebuild(child_exprs); + std::shared_ptr c = std::dynamic_pointer_cast(expr->Node()); + PELOTON_ASSERT(c != nullptr); + + return c->CopyWithChildren(child_exprs); +} + +std::shared_ptr Rewriter::ConvertToAbsExpr(const expression::AbstractExpression* expr) { + // TODO(): remove the Copy invocation when in terrier since terrier uses shared_ptr + // + // This Copy() is not very efficient at all. but within Peloton, this is the only way + // to present a std::shared_ptr to the AbsExpr_Container/Expression classes. In terrier, + // this Copy() is *definitely* not needed because the AbstractExpression there already + // utilizes std::shared_ptr properly. + std::shared_ptr copy = std::shared_ptr(expr->Copy()); + + // Create current AbsExpr_Expression + auto container = std::make_shared(copy); + auto expression = std::make_shared(container); + + // Convert all the children + size_t child_count = expr->GetChildrenSize(); + for (size_t i = 0; i < child_count; i++) { + expression->PushChild(ConvertToAbsExpr(expr->GetChild(i))); + } + + copy->ClearChildren(); + return expression; +} + +std::shared_ptr Rewriter::RecordTreeGroups(const expression::AbstractExpression *expr) { + std::shared_ptr exp = ConvertToAbsExpr(expr); + std::shared_ptr gexpr; + metadata_.RecordTransformedExpression(exp, gexpr); + return gexpr; } expression::AbstractExpression* Rewriter::RewriteExpression(const expression::AbstractExpression *expr) { - // (TODO): do we need to actually convert to a wrapper? - // This is needed in order to provide template classes the correct interface. - // This should probably be better abstracted away. 
- std::shared_ptr gexpr = ConvertTree(expr); + if (expr == nullptr) + return nullptr; + + // This is needed in order to provide template classes the correct interface + // and also handle immutable AbstractExpression. + std::shared_ptr gexpr = RecordTreeGroups(expr); LOG_DEBUG("Converted tree to internal data structures"); GroupID root_id = gexpr->GetGroupID(); @@ -110,41 +143,5 @@ expression::AbstractExpression* Rewriter::RewriteExpression(const expression::Ab return expr_tree; } -void Rewriter::Reset() { - metadata_ = OptimizerMetadataTemplate(nullptr); -} - -std::shared_ptr Rewriter::ConvertToAbsExpr(const expression::AbstractExpression* expr) { - - // (TODO): fix memory management once we get to terrier - // for now, this just directly wraps each AbstractExpression in a AbsExpr_Container - // which is then wrapped in an AbsExpr_Expression to provide the same Operator/OperatorExpression - // interface that is relied upon by the rest of the code base. - - auto container = AbsExpr_Container(expr); - auto exp = std::make_shared(container); - for (size_t i = 0; i < expr->GetChildrenSize(); i++) { - exp->PushChild(ConvertToAbsExpr(expr->GetChild(i))); - } - return exp; -} - -std::shared_ptr Rewriter::ConvertTree( - const expression::AbstractExpression *expr) { - - std::shared_ptr exp = ConvertToAbsExpr(expr); - std::shared_ptr gexpr; - metadata_.RecordTransformedExpression(exp, gexpr); - return gexpr; -} - -void Rewriter::ExecuteTaskStack(OptimizerTaskStackTemplate &task_stack) { - // Iterate through the task stack - while (!task_stack.Empty()) { - auto task = task_stack.Pop(); - task->execute(); - } -} - } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index fde53e6b89d..3baec8dba86 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -20,31 +20,63 @@ namespace optimizer { int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - // TODO(ncx): replace 
after pattern fix, and specialize to operators - // auto root_type = match_pattern->OpType(); - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); + auto root_type_exp = match_pattern->GetExpType(); + // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { + if (root_type != OpType::Undefined && + root_type != OpType::Leaf && + root_type != group_expr->Node()->GetOpType()) { + return 0; + } + + if (root_type_exp != ExpressionType::INVALID && + root_type_exp != ExpressionType::GROUP_MARKER && + root_type_exp != group_expr->Node()->GetExpType()) { return 0; } + if (IsPhysical()) return PHYS_PROMISE; return LOG_PROMISE; } +// TODO(ncx): best way to specialize for constructors? RuleSet::RuleSet() { - LOG_ERROR("Must invoke specialization of RuleSet constructor"); - PELOTON_ASSERT(0); -} + // Comparator Elimination related rules + std::vector> comp_elim_pairs = { + std::make_pair(RuleType::CONSTANT_COMPARE_EQUAL, ExpressionType::COMPARE_EQUAL), + std::make_pair(RuleType::CONSTANT_COMPARE_NOTEQUAL, ExpressionType::COMPARE_NOTEQUAL), + std::make_pair(RuleType::CONSTANT_COMPARE_LESSTHAN, ExpressionType::COMPARE_LESSTHAN), + std::make_pair(RuleType::CONSTANT_COMPARE_GREATERTHAN, ExpressionType::COMPARE_GREATERTHAN), + std::make_pair(RuleType::CONSTANT_COMPARE_LESSTHANOREQUALTO, ExpressionType::COMPARE_LESSTHANOREQUALTO), + std::make_pair(RuleType::CONSTANT_COMPARE_GREATERTHANOREQUALTO, ExpressionType::COMPARE_GREATERTHANOREQUALTO) + }; -// TODO(ncx): best way to specialize for constructors? 
-template <> -RuleSet::RuleSet() { - AddRewriteRule(RewriteRuleSetName::COMPARATOR_ELIMINATION, - new ComparatorElimination()); -} + for (auto &pair : comp_elim_pairs) { + AddRewriteRule( + RewriteRuleSetName::GENERIC_RULES, + new ComparatorElimination(pair.first, pair.second) + ); + } + + // Equivalent Transform related rules (flip AND, OR, EQUAL) + std::vector> equiv_pairs = { + std::make_pair(RuleType::EQUIV_AND, ExpressionType::CONJUNCTION_AND), + std::make_pair(RuleType::EQUIV_OR, ExpressionType::CONJUNCTION_OR), + std::make_pair(RuleType::EQUIV_COMPARE_EQUAL, ExpressionType::COMPARE_EQUAL) + }; + for (auto &pair : equiv_pairs) { + AddRewriteRule( + RewriteRuleSetName::EQUIVALENT_TRANSFORM, + new EquivalentTransform(pair.first, pair.second) + ); + } + + // Additional rules + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TVEqualityWithTwoCVTransform()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TransitiveClosureConstantTransform()); -template <> -RuleSet::RuleSet() { + // Define transformation/implementation rules AddTransformationRule(new InnerJoinCommutativity()); AddTransformationRule(new InnerJoinAssociativity()); AddImplementationRule(new LogicalDeleteToPhysical()); @@ -64,6 +96,7 @@ RuleSet::RuleSet() { AddImplementationRule(new ImplementLimit()); AddImplementationRule(new LogicalExportToPhysicalExport()); + // Query optimizer related rewrite rules AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, new PushFilterThroughJoin()); AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 24a3b541164..530bfedee99 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -1116,7 +1116,7 @@ MarkJoinToInnerJoin::MarkJoinToInnerJoin() { int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); // 
This rule is not applicable if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; @@ -1167,7 +1167,7 @@ SingleJoinToInnerJoin::SingleJoinToInnerJoin() { int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); // This rule is not applicable if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; @@ -1220,7 +1220,7 @@ PullFilterThroughMarkJoin::PullFilterThroughMarkJoin() { int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); // This rule is not applicable if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; @@ -1281,7 +1281,7 @@ PullFilterThroughAggregation::PullFilterThroughAggregation() { int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); // This rule is not applicable if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp index 88d23092c31..b804c08e488 100644 --- a/src/optimizer/rule_rewrite.cpp +++ b/src/optimizer/rule_rewrite.cpp @@ -3,6 +3,7 @@ #include "catalog/column_catalog.h" #include "catalog/index_catalog.h" #include "catalog/table_catalog.h" +#include "expression/group_marker_expression.h" #include "optimizer/operators.h" #include "optimizer/absexpr_expression.h" #include "optimizer/optimizer_metadata.h" @@ -14,78 +15,389 @@ namespace peloton { namespace optimizer { -ComparatorElimination::ComparatorElimination() { - type_ = RuleType::COMP_EQUALITY_ELIMINATION; +// 
=========================================================== +// +// ComparatorElimination related functions +// +// =========================================================== +ComparatorElimination::ComparatorElimination(RuleType rule, ExpressionType root) { + type_ = rule; - match_pattern = std::make_shared>(ExpressionType::COMPARE_EQUAL); - auto left = std::make_shared>(ExpressionType::VALUE_CONSTANT); - auto right = std::make_shared>(ExpressionType::VALUE_CONSTANT); + auto left = std::make_shared(ExpressionType::VALUE_CONSTANT); + auto right = std::make_shared(ExpressionType::VALUE_CONSTANT); + match_pattern = std::make_shared(root); match_pattern->AddChild(left); match_pattern->AddChild(right); } -int ComparatorElimination::Promise(GroupExpression *group_expr, - OptimizeContext *context) const { +int ComparatorElimination::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::HIGH); + return static_cast(RulePriority::MEDIUM); } -bool ComparatorElimination::Check(std::shared_ptr plan, - OptimizeContext *context) const { +bool ComparatorElimination::Check(std::shared_ptr plan, + OptimizeContext *context) const { (void)context; (void)plan; - - // If any of these assertions fail, something is seriously wrong with GroupExprBinding - // Verify the structure of the tree is correct - PELOTON_ASSERT(plan != nullptr); - PELOTON_ASSERT(plan->Children().size() == 2); - PELOTON_ASSERT(plan->Op().GetType() == ExpressionType::COMPARE_EQUAL); - - auto left = plan->Children()[0]; - auto right = plan->Children()[1]; - PELOTON_ASSERT(left->Children().size() == 0); - PELOTON_ASSERT(left->Op().GetType() == ExpressionType::VALUE_CONSTANT); - PELOTON_ASSERT(right->Children().size() == 0); - PELOTON_ASSERT(right->Op().GetType() == ExpressionType::VALUE_CONSTANT); - - // Technically, if structure matches, rule should always be applied return true; } -void 
ComparatorElimination::Transform(std::shared_ptr input, - std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { +void ComparatorElimination::Transform(std::shared_ptr input, + std::vector> &transformed, + UNUSED_ATTRIBUTE OptimizeContext *context) const { (void)transformed; (void)context; - // (TODO): create a wrapper for evaluating ConstantValue relations (pending email reply) + // Extract the AbstractExpression through indirection layer. + // Since the binding succeeded, there are guaranteed to be two children. + PELOTON_ASSERT(input->Children().size() == 2); + + auto left_abs = std::dynamic_pointer_cast(input->Children()[0]->Node()); + auto right_abs = std::dynamic_pointer_cast(input->Children()[1]->Node()); + PELOTON_ASSERT(left_abs != nullptr && right_abs != nullptr); - // Extract the AbstractExpression through indirection layer - auto left = input->Children()[0]->Op().GetExpr(); - auto right = input->Children()[1]->Op().GetExpr(); - auto lv = static_cast(left); - auto rv = static_cast(right); - lv = const_cast(lv); - rv = const_cast(rv); + auto left = left_abs->GetExpr(); + auto right = right_abs->GetExpr(); + auto lv = std::dynamic_pointer_cast(left); + auto rv = std::dynamic_pointer_cast(right); + PELOTON_ASSERT(lv != nullptr && rv != nullptr); // Get the Value from ConstantValueExpression auto lvalue = lv->GetValue(); auto rvalue = rv->GetValue(); + if (lvalue.CheckComparable(rvalue)) { + CmpBool cmp = CmpBool::CmpTrue; + switch (type_) { + case RuleType::CONSTANT_COMPARE_EQUAL: + cmp = lvalue.CompareEquals(rvalue); + break; + case RuleType::CONSTANT_COMPARE_NOTEQUAL: + cmp = lvalue.CompareNotEquals(rvalue); + break; + case RuleType::CONSTANT_COMPARE_LESSTHAN: + cmp = lvalue.CompareLessThan(rvalue); + break; + case RuleType::CONSTANT_COMPARE_GREATERTHAN: + cmp = lvalue.CompareGreaterThan(rvalue); + break; + case RuleType::CONSTANT_COMPARE_LESSTHANOREQUALTO: + // lv <= rv does not have a predefined function so we do the following: 
+ // (1) Compute truth value of lvalue > rvalue + // (2) Flip the truth value unless CmpBool::NULL_ + cmp = lvalue.CompareGreaterThan(rvalue); + if (cmp != CmpBool::NULL_) { + cmp = (cmp == CmpBool::CmpFalse) ? CmpBool::CmpTrue : CmpBool::CmpFalse; + } + break; + case RuleType::CONSTANT_COMPARE_GREATERTHANOREQUALTO: + cmp = lvalue.CompareGreaterThanEquals(rvalue); + break; + default: + // Other type_ should not be handled by this rule + int type = static_cast(type_); + LOG_ERROR("lvalue compare rvalue with RuleType (%d) not implemented", type); + PELOTON_ASSERT(0); + break; + } + + // Create the replacement + type::Value val = type::ValueFactory::GetBooleanValue(cmp); + auto expr = std::make_shared(val); + auto container = std::make_shared(AbsExpr_Container(expr)); + auto shared = std::make_shared(container); + transformed.push_back(shared); + } + + // If the values cannot be comparable, we leave them as is. + // We don't throw an error or anything because it is possible this branch + // may be collapsed due to subsequent optimizations, and it is likely + // any error will be caught during actual query execution. 
+ return; +} + +// =========================================================== +// +// EquivalentTransform related functions +// +// =========================================================== +EquivalentTransform::EquivalentTransform(RuleType rule, ExpressionType root) { + type_ = rule; + + auto left = std::make_shared(ExpressionType::GROUP_MARKER); + auto right = std::make_shared(ExpressionType::GROUP_MARKER); + match_pattern = std::make_shared(root); + match_pattern->AddChild(left); + match_pattern->AddChild(right); +} - // Need to check type equality to prevent assertion failure - // This is only a Peloton issue (terrier checks type for you) - bool is_equal = (lvalue.GetTypeId() == rvalue.GetTypeId()) && - (lv->ExactlyEquals(*rv)); +int EquivalentTransform::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::HIGH); +} - // Create the transformed expression - type::Value val = type::ValueFactory::GetBooleanValue(is_equal); - auto eq = new expression::ConstantValueExpression(val); - auto cnt = AbsExpr_Container(eq); - auto shared = std::make_shared(cnt); +bool EquivalentTransform::Check(std::shared_ptr plan, + OptimizeContext *context) const { + (void)context; + (void)plan; + return true; +} - // (TODO): figure out memory management once go to terrier (which use shared_ptr) +void EquivalentTransform::Transform(std::shared_ptr input, + std::vector> &transformed, + UNUSED_ATTRIBUTE OptimizeContext *context) const { + (void)transformed; + (void)context; + + // We expect EquivalentTransform to operate on AND / OR which should + // have exactly 2 children for the expression to logically make sense. 
+ PELOTON_ASSERT(input->Children().size() == 2); + + // Create flipped ordering + auto left = input->Children()[0]; + auto right = input->Children()[1]; + + // The children do not strictly matter anymore + auto type = match_pattern->GetExpType(); + auto expr = std::make_shared(type); + auto a_expr = std::make_shared(expr); + auto shared = std::make_shared(a_expr); + + // Create flipped ordering at logical level + shared->PushChild(right); + shared->PushChild(left); transformed.push_back(shared); + return; +} + + +// =========================================================== +// +// Transitive-Transform related functions +// +// =========================================================== +TVEqualityWithTwoCVTransform::TVEqualityWithTwoCVTransform() { + type_ = RuleType::TV_EQUALITY_WITH_TWO_CV; + + // (A.B = x) AND (A.B = y) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_AND); + + auto l_eq = std::make_shared(ExpressionType::COMPARE_EQUAL); + auto l_left = std::make_shared(ExpressionType::VALUE_TUPLE); + auto l_right = std::make_shared(ExpressionType::VALUE_CONSTANT); + l_eq->AddChild(l_left); + l_eq->AddChild(l_right); + + auto r_eq = std::make_shared(ExpressionType::COMPARE_EQUAL); + auto r_left = std::make_shared(ExpressionType::VALUE_TUPLE); + auto r_right = std::make_shared(ExpressionType::VALUE_CONSTANT); + r_eq->AddChild(r_left); + r_eq->AddChild(r_right); + + match_pattern->AddChild(l_eq); + match_pattern->AddChild(r_eq); +} + +int TVEqualityWithTwoCVTransform::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::LOW); +} + +bool TVEqualityWithTwoCVTransform::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; } + +void TVEqualityWithTwoCVTransform::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + //TODO(wz2): 
TVEqualityWithTwoCVTransform should work beyond straight equality + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (A.B = x) AND (A.B = y) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_AND); + + auto l_eq = input->Children()[0]; + auto r_eq = input->Children()[1]; + PELOTON_ASSERT(l_eq->Children().size() == 2); + PELOTON_ASSERT(r_eq->Children().size() == 2); + PELOTON_ASSERT(l_eq->Node()->GetExpType() == ExpressionType::COMPARE_EQUAL); + PELOTON_ASSERT(r_eq->Node()->GetExpType() == ExpressionType::COMPARE_EQUAL); + + auto l_tv = l_eq->Children()[0]; + auto l_cv = l_eq->Children()[1]; + PELOTON_ASSERT(l_tv->Children().size() == 0); + PELOTON_ASSERT(l_cv->Children().size() == 0); + PELOTON_ASSERT(l_tv->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(l_cv->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + auto r_tv = r_eq->Children()[0]; + auto r_cv = r_eq->Children()[1]; + PELOTON_ASSERT(r_tv->Children().size() == 0); + PELOTON_ASSERT(r_cv->Children().size() == 0); + PELOTON_ASSERT(r_tv->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(r_cv->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + auto l_tv_c = std::dynamic_pointer_cast(l_tv->Node()); + auto r_tv_c = std::dynamic_pointer_cast(r_tv->Node()); + auto l_cv_c = std::dynamic_pointer_cast(l_cv->Node()); + auto r_cv_c = std::dynamic_pointer_cast(r_cv->Node()); + PELOTON_ASSERT(l_tv_c != nullptr && r_tv_c != nullptr); + PELOTON_ASSERT(l_cv_c != nullptr && r_cv_c != nullptr); + + auto l_tv_expr = l_tv_c->GetExpr(); + auto r_tv_expr = r_tv_c->GetExpr(); + if (l_tv_expr->ExactlyEquals(*r_tv_expr)) { + // Given the pattern (A.B = x) AND (C.D = y), the IF statement asserts that A.B is the same as C.D + // TODO(wz2): ExactlyEquals may be too strict, since must match bound_oid, table_name, col_name + + // ExactlyEqual on TupleValueExpression 
has sufficient check + auto l_cv_expr = std::dynamic_pointer_cast(l_cv_c->GetExpr()); + auto r_cv_expr = std::dynamic_pointer_cast(r_cv_c->GetExpr()); + auto l_cv_val = l_cv_expr->GetValue(); + auto r_cv_val = r_cv_expr->GetValue(); + if (l_cv_val.CheckComparable(r_cv_val)) { + // Given a pattern (A.B = x) AND (A.B = y), we perform the following: + // - Rewrite expression to (A.B = x) if x == y + // - Rewrite expression to FALSE if x != y (including if x / y is NULL) + bool is_eq = false; + if (l_cv_val.CompareEquals(r_cv_val) == CmpBool::CmpTrue) { + // This means in the pattern (A.B = x) AND (A.B = y) that x == y + is_eq = true; + } + + if (is_eq) { + transformed.push_back(l_eq); + } else { + auto val = type::ValueFactory::GetBooleanValue(false); + auto constant = std::make_shared(val); + auto abs_expr = std::make_shared(std::make_shared(AbsExpr_Container(constant))); + transformed.push_back(abs_expr); + } + + return; + } + } + + return; +} + +TransitiveClosureConstantTransform::TransitiveClosureConstantTransform() { + type_ = RuleType::TRANSITIVE_CLOSURE_CONSTANT; + + // (A.B = x) AND (A.B = C.D) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_AND); + + auto l_eq = std::make_shared(ExpressionType::COMPARE_EQUAL); + auto l_left = std::make_shared(ExpressionType::VALUE_TUPLE); + auto l_right = std::make_shared(ExpressionType::VALUE_CONSTANT); + l_eq->AddChild(l_left); + l_eq->AddChild(l_right); + + auto r_eq = std::make_shared(ExpressionType::COMPARE_EQUAL); + auto r_left = std::make_shared(ExpressionType::VALUE_TUPLE); + auto r_right = std::make_shared(ExpressionType::VALUE_TUPLE); + r_eq->AddChild(r_left); + r_eq->AddChild(r_right); + + match_pattern->AddChild(l_eq); + match_pattern->AddChild(r_eq); +} + +int TransitiveClosureConstantTransform::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::LOW); +} + +bool 
TransitiveClosureConstantTransform::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void TransitiveClosureConstantTransform::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + //TODO(wz2): TransitiveClosureConstant should work beyond straight equality + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (A.B = x) AND (A.B = C.D) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_AND); + + auto l_eq = input->Children()[0]; + auto r_eq = input->Children()[1]; + PELOTON_ASSERT(l_eq->Children().size() == 2); + PELOTON_ASSERT(r_eq->Children().size() == 2); + PELOTON_ASSERT(l_eq->Node()->GetExpType() == ExpressionType::COMPARE_EQUAL); + PELOTON_ASSERT(r_eq->Node()->GetExpType() == ExpressionType::COMPARE_EQUAL); + + auto l_tv = l_eq->Children()[0]; + auto l_cv = l_eq->Children()[1]; + PELOTON_ASSERT(l_tv->Children().size() == 0); + PELOTON_ASSERT(l_cv->Children().size() == 0); + PELOTON_ASSERT(l_tv->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(l_cv->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + auto r_tv_l = r_eq->Children()[0]; + auto r_tv_r = r_eq->Children()[1]; + PELOTON_ASSERT(r_tv_l->Children().size() == 0); + PELOTON_ASSERT(r_tv_r->Children().size() == 0); + PELOTON_ASSERT(r_tv_l->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(r_tv_r->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + + auto l_tv_c = std::dynamic_pointer_cast(l_tv->Node()); + auto r_tv_l_c = std::dynamic_pointer_cast(r_tv_l->Node()); + auto r_tv_r_c = std::dynamic_pointer_cast(r_tv_r->Node()); + PELOTON_ASSERT(l_tv_c != nullptr && r_tv_l_c != nullptr && r_tv_r_c != nullptr); + + auto l_tv_expr = l_tv_c->GetExpr(); + auto r_tv_l_expr = r_tv_l_c->GetExpr(); + auto r_tv_r_expr = 
r_tv_r_c->GetExpr(); + + // At this stage, we have the arbitrary structure: (A.B = x) AND (C.D = E.F) + // TODO(wz2): ExactlyEquals for TupleValue may be too strict, since must match bound_oid, table_name, col_name + if (r_tv_l_expr->ExactlyEquals(*r_tv_r_expr)) { + // Handles case where C.D = E.F, which can rewrite to just A.B = x + transformed.push_back(l_eq); + return; + } + + if (!l_tv_expr->ExactlyEquals(*r_tv_l_expr) && !l_tv_expr->ExactlyEquals(*r_tv_r_expr)) { + // We know that A.B != C.D and A.B != E.F, so no optimization possible + return; + } + + auto new_left_eq = l_eq; + auto right_val_copy = l_cv; + auto new_right_eq = std::make_shared(r_eq->Node()); + + // At this stage, we have knowledge that C.D != E.F + if (l_tv_expr->ExactlyEquals(*r_tv_l_expr)) { + // At this stage, we have knowledge that A.B = C.D + new_right_eq->PushChild(right_val_copy); + new_right_eq->PushChild(r_tv_r); + } else { + // At this stage, we have knowledge that A.B = E.F + new_right_eq->PushChild(r_tv_l); + new_right_eq->PushChild(right_val_copy); + } + + // Create new root expression + auto abs_expr = std::make_shared(input->Node()); + abs_expr->PushChild(new_left_eq); + abs_expr->PushChild(new_right_eq); + transformed.push_back(abs_expr); +} + } // namespace optimizer } // namespace peloton diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index c6f8df81fda..31dab68579f 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -317,6 +317,24 @@ std::shared_ptr TrafficCop::PrepareStatement( tcop_txn_state_.top().first, default_database_name_); bind_node_visitor.BindNameToNode( statement->GetStmtParseTreeList()->GetStatement(0)); + + // Apply the rewriter if possible on the top statement + // TODO(): Rewrite more; move into optimizer and rewrite after unnesting + auto top_stmt = statement->GetStmtParseTreeList()->GetStatement(0); + switch (top_stmt->GetType()) { + case StatementType::SELECT: { + auto select = 
dynamic_cast(top_stmt); + + expression::AbstractExpression *where = select->where_clause.get(); + auto optimal = rewriter_.RewriteExpression(where); + select->UpdateWhereClause(optimal); + rewriter_.Reset(); + break; + } + default: + break; + } + auto plan = optimizer_->BuildPelotonPlanTree( statement->GetStmtParseTreeList(), tcop_txn_state_.top().first); statement->SetPlanTree(plan); @@ -563,15 +581,15 @@ ResultType TrafficCop::ExecuteStatement( statement, std::move(param_stats)); } - LOG_TRACE("Execute Statement of name: %s", + LOG_DEBUG("Execute Statement of name: %s", statement->GetStatementName().c_str()); - LOG_TRACE("Execute Statement of query: %s", + LOG_DEBUG("Execute Statement of query: %s", statement->GetQueryString().c_str()); - LOG_TRACE("Execute Statement Plan:\n%s", + LOG_DEBUG("Execute Statement Plan:\n%s", planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - LOG_TRACE("Execute Statement Query Type: %s", + LOG_DEBUG("Execute Statement Query Type: %s", statement->GetQueryTypeString().c_str()); - LOG_TRACE("----QueryType: %d--------", + LOG_DEBUG("----QueryType: %d--------", static_cast(statement->GetQueryType())); try { diff --git a/test/include/optimizer/mock_task.h b/test/include/optimizer/mock_task.h index 7e18f458445..32e5e1b8da4 100644 --- a/test/include/optimizer/mock_task.h +++ b/test/include/optimizer/mock_task.h @@ -20,10 +20,10 @@ namespace peloton { namespace optimizer { namespace test { -class MockTask : public optimizer::OptimizerTask { +class MockTask : public optimizer::OptimizerTask { public: MockTask() - : optimizer::OptimizerTask(nullptr, OptimizerTaskType::OPTIMIZE_GROUP) {} + : optimizer::OptimizerTask(nullptr, OptimizerTaskType::OPTIMIZE_GROUP) {} MOCK_METHOD0(execute, void()); }; diff --git a/test/optimizer/absexpr_test.cpp b/test/optimizer/absexpr_test.cpp new file mode 100644 index 00000000000..1e8d233c82d --- /dev/null +++ b/test/optimizer/absexpr_test.cpp @@ -0,0 +1,460 @@ 
+//===----------------------------------------------------------------------===// +// +// Peloton +// +// absexpr_test.cpp +// +// Identification: test/optimizer/absexpr_test.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include "common/harness.h" + +#include "function/functions.h" +#include "optimizer/operators.h" +#include "optimizer/rewriter.h" +#include "expression/aggregate_expression.h" +#include "expression/conjunction_expression.h" +#include "expression/subquery_expression.h" +#include "expression/parameter_value_expression.h" +#include "expression/star_expression.h" +#include "expression/case_expression.h" +#include "expression/tuple_value_expression.h" +#include "expression/operator_expression.h" +#include "expression/constant_value_expression.h" +#include "expression/comparison_expression.h" +#include "expression/tuple_value_expression.h" +#include "expression/function_expression.h" +#include "type/value_factory.h" +#include "type/value_peeker.h" +#include "optimizer/rule_rewrite.h" +#include "parser/postgresparser.h" + +namespace peloton { + +namespace test { + +using namespace optimizer; + +class AbsExprTest : public PelotonTest { + public: + // Returns expression of (Constant = Tuple Value) + expression::AbstractExpression *GetTVEqualCVExpression(std::string col, int val) { + auto val_e = GetConstantExpression(val); + auto tuple = new expression::TupleValueExpression(std::move(col)); + auto cmp = new expression::ComparisonExpression( + ExpressionType::COMPARE_EQUAL, val_e, tuple + ); + + return cmp; + } + + // Returns ConstantExpression(val) + expression::AbstractExpression *GetConstantExpression(int val) { + auto value = type::ValueFactory::GetIntegerValue(val); + return new expression::ConstantValueExpression(value); + } +}; + +TEST_F(AbsExprTest, CompareTest) { + std::vector compares = { + 
ExpressionType::COMPARE_EQUAL, + ExpressionType::COMPARE_NOTEQUAL, + ExpressionType::COMPARE_LESSTHAN, + ExpressionType::COMPARE_GREATERTHAN, + ExpressionType::COMPARE_LESSTHANOREQUALTO, + ExpressionType::COMPARE_GREATERTHANOREQUALTO, + ExpressionType::COMPARE_LIKE, + ExpressionType::COMPARE_NOTLIKE, + ExpressionType::COMPARE_IN, + ExpressionType::COMPARE_DISTINCT_FROM + }; + + auto left = new expression::ParameterValueExpression(0); + auto right = new expression::ParameterValueExpression(1); + for (auto type : compares) { + auto cmp_expr = std::make_shared(type, left->Copy(), right->Copy()); + AbsExpr_Container op = AbsExpr_Container(cmp_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + + EXPECT_EQ(cmp_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(cmp_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(cmp_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + + EXPECT_EQ(*(cmp_expr->GetChild(0)), *(rebuilt->GetChild(0))); + EXPECT_EQ(*(cmp_expr->GetChild(1)), *(rebuilt->GetChild(1))); + + auto l_child = dynamic_cast(rebuilt->GetChild(0)); + auto r_child = dynamic_cast(rebuilt->GetChild(1)); + EXPECT_TRUE(l_child != nullptr && r_child != nullptr); + EXPECT_TRUE(l_child->GetValueIdx() == 0 && r_child->GetValueIdx() == 1); + + delete rebuilt; + } + + delete left; + delete right; +} + +TEST_F(AbsExprTest, ConjunctionTest) { + std::vector compares = { + ExpressionType::CONJUNCTION_AND, + ExpressionType::CONJUNCTION_OR + }; + + auto tval = type::ValueFactory::GetBooleanValue(true); + auto fval = type::ValueFactory::GetBooleanValue(false); + auto left = new expression::ConstantValueExpression(tval); + auto right = new expression::ConstantValueExpression(fval); + for (auto type : compares) { + auto cmp_expr = std::make_shared(type, left->Copy(), right->Copy()); + AbsExpr_Container op = AbsExpr_Container(cmp_expr); + expression::AbstractExpression 
*rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + + EXPECT_EQ(cmp_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(cmp_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(cmp_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + + auto l_child = dynamic_cast(rebuilt->GetChild(0)); + auto r_child = dynamic_cast(rebuilt->GetChild(1)); + EXPECT_TRUE(l_child != nullptr && r_child != nullptr); + EXPECT_TRUE(l_child->ExactlyEquals(*left) && r_child->ExactlyEquals(*right)); + delete rebuilt; + } + + delete left; + delete right; +} + +TEST_F(AbsExprTest, OperatorTest) { + std::vector operators = { + ExpressionType::OPERATOR_PLUS, + ExpressionType::OPERATOR_MINUS, + ExpressionType::OPERATOR_MULTIPLY, + ExpressionType::OPERATOR_DIVIDE, + ExpressionType::OPERATOR_CONCAT, + ExpressionType::OPERATOR_MOD, + }; + + std::vector single_ops = { + ExpressionType::OPERATOR_NOT, + ExpressionType::OPERATOR_IS_NULL, + ExpressionType::OPERATOR_IS_NOT_NULL, + ExpressionType::OPERATOR_EXISTS + }; + + auto left = GetConstantExpression(25); + auto right = GetConstantExpression(30); + for (auto type : operators) { + auto op_expr = std::make_shared(type, type::TypeId::INTEGER, left->Copy(), right->Copy()); + op_expr->DeduceExpressionType(); + + AbsExpr_Container op = AbsExpr_Container(op_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + rebuilt->DeduceExpressionType(); + + EXPECT_EQ(op_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(op_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(op_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + + auto l_child = dynamic_cast(rebuilt->GetChild(0)); + auto r_child = dynamic_cast(rebuilt->GetChild(1)); + EXPECT_TRUE(l_child != nullptr && r_child != nullptr); + EXPECT_TRUE(l_child->ExactlyEquals(*(op_expr->GetChild(0)))); + 
EXPECT_TRUE(r_child->ExactlyEquals(*(op_expr->GetChild(1)))); + EXPECT_TRUE(l_child->ExactlyEquals(*left)); + EXPECT_TRUE(r_child->ExactlyEquals(*right)); + + delete rebuilt; + } + + for (auto type : single_ops) { + auto op_expr = std::make_shared(type, type::TypeId::INTEGER, left->Copy(), nullptr); + op_expr->DeduceExpressionType(); + + AbsExpr_Container op = AbsExpr_Container(op_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + rebuilt->DeduceExpressionType(); + + EXPECT_EQ(op_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(op_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(op_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + + auto l_child = dynamic_cast(rebuilt->GetChild(0)); + EXPECT_TRUE(l_child != nullptr); + EXPECT_TRUE(l_child->ExactlyEquals(*(op_expr->GetChild(0)))); + EXPECT_TRUE(l_child->ExactlyEquals(*left)); + + delete rebuilt; + } + + delete left; + delete right; +} + +TEST_F(AbsExprTest, OperatorUnaryMinusTest) { + auto left = GetConstantExpression(25); + auto unary = std::make_shared(left->Copy()); + + AbsExpr_Container op = AbsExpr_Container(unary); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + + EXPECT_EQ(unary->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(unary->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(unary->GetChildrenSize(), rebuilt->GetChildrenSize()); + EXPECT_TRUE(unary->GetChild(0)->ExactlyEquals(*(rebuilt->GetChild(0)))); + EXPECT_TRUE(left->ExactlyEquals(*(rebuilt->GetChild(0)))); + delete rebuilt; + delete left; +} + +TEST_F(AbsExprTest, StarTest) { + auto expr = std::make_shared(); + AbsExpr_Container op = AbsExpr_Container(expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + EXPECT_EQ(*expr, *rebuilt); + delete rebuilt; +} + +TEST_F(AbsExprTest, ValueConstantTest) { + auto cv_expr = 
dynamic_cast(GetConstantExpression(721)); + auto expr = std::shared_ptr(cv_expr); + AbsExpr_Container op = AbsExpr_Container(expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + EXPECT_EQ(*expr, *rebuilt); // this does not check value + EXPECT_EQ(expr->GetValueType(), rebuilt->GetValueType()); + + auto lvalue = expr->GetValue(); + + auto rebuilt_val = dynamic_cast(rebuilt); + auto rvalue = rebuilt_val->GetValue(); + EXPECT_TRUE(lvalue.CheckComparable(expr->GetValue())); + EXPECT_TRUE(lvalue.CheckComparable(rvalue)); + EXPECT_TRUE(lvalue.CompareEquals(expr->GetValue()) == CmpBool::CmpTrue); + EXPECT_TRUE(lvalue.CompareEquals(rvalue) == CmpBool::CmpTrue); + delete rebuilt; +} + +TEST_F(AbsExprTest, ValueParameterTest) { + auto expr = std::make_shared(15); + AbsExpr_Container op = AbsExpr_Container(expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + EXPECT_EQ(*expr, *rebuilt); // does not check value_idx_ + + auto rebuilt_val = dynamic_cast(rebuilt); + EXPECT_EQ(expr->GetValueIdx(), rebuilt_val->GetValueIdx()); + delete rebuilt; +} + +TEST_F(AbsExprTest, ValueTupleTest) { + auto expr_col = std::make_shared("col"); + expr_col->SetTupleValueExpressionParams(type::TypeId::INTEGER, 1, 1); + expr_col->SetTableName("tbl"); + + AbsExpr_Container op = AbsExpr_Container(expr_col); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + EXPECT_EQ(*expr_col, *rebuilt); // checks tbl_name, col_name + + auto rebuilt_val = dynamic_cast(rebuilt); + EXPECT_EQ(rebuilt_val->GetColumnId(), expr_col->GetColumnId()); + EXPECT_EQ(rebuilt_val->GetTableName(), expr_col->GetTableName()); + EXPECT_EQ(rebuilt_val->GetColumnName(), expr_col->GetColumnName()); + EXPECT_EQ(rebuilt_val->GetTupleId(), expr_col->GetTupleId()); + EXPECT_EQ(rebuilt_val->GetValueType(), expr_col->GetValueType()); + delete rebuilt; +} + +TEST_F(AbsExprTest, AggregateNodeTest) { + std::vector aggregates = { + ExpressionType::AGGREGATE_COUNT, + 
ExpressionType::AGGREGATE_SUM, + ExpressionType::AGGREGATE_MIN, + ExpressionType::AGGREGATE_MAX, + ExpressionType::AGGREGATE_AVG + }; + + // Generic aggregation + for (auto type : aggregates) { + auto child = new expression::TupleValueExpression("col_a"); + auto agg_expr = std::make_shared(type, true, child->Copy()); + agg_expr->DeduceExpressionType(); + + AbsExpr_Container op = AbsExpr_Container(agg_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({child->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + + rebuilt->DeduceExpressionType(); + EXPECT_EQ(agg_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(agg_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(agg_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + EXPECT_EQ(*(agg_expr->GetChild(0)), *(rebuilt->GetChild(0))); + + EXPECT_TRUE(agg_expr->distinct_); + EXPECT_TRUE(rebuilt->distinct_); + + delete rebuilt; + delete child; + } + + // COUNT (*) Aggregation + auto child = new expression::StarExpression(); + auto agg_expr = std::make_shared(ExpressionType::AGGREGATE_COUNT, true, child); + + agg_expr->DeduceExpressionType(); + EXPECT_TRUE(agg_expr->GetExpressionType() == ExpressionType::AGGREGATE_COUNT_STAR); + + AbsExpr_Container op = AbsExpr_Container(agg_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + rebuilt->DeduceExpressionType(); + + EXPECT_EQ(agg_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(agg_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(agg_expr->distinct_, rebuilt->distinct_); + EXPECT_EQ(rebuilt->GetChildrenSize(), 0); + delete rebuilt; +} + +TEST_F(AbsExprTest, CaseExpressionTest) { + auto where1 = expression::CaseExpression::AbsExprPtr(GetTVEqualCVExpression("col_a", 1)); + auto where2 = expression::CaseExpression::AbsExprPtr(GetTVEqualCVExpression("col_b", 2)); + auto where3 = expression::CaseExpression::AbsExprPtr(GetTVEqualCVExpression("col_c", 3)); + auto def_c = 
expression::CaseExpression::AbsExprPtr(GetConstantExpression(4)); + + auto res1 = expression::CaseExpression::AbsExprPtr(GetConstantExpression(1)); + auto res2 = expression::CaseExpression::AbsExprPtr(GetConstantExpression(2)); + auto res3 = expression::CaseExpression::AbsExprPtr(GetConstantExpression(3)); + std::vector clauses; + clauses.push_back(expression::CaseExpression::WhenClause(std::move(where1), std::move(res1))); + clauses.push_back(expression::CaseExpression::WhenClause(std::move(where2), std::move(res2))); + clauses.push_back(expression::CaseExpression::WhenClause(std::move(where3), std::move(res3))); + + auto expr = std::make_shared(type::TypeId::INTEGER, clauses, std::move(def_c)); + AbsExpr_Container op = AbsExpr_Container(expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + // Checks every clause except for ConstantValue values + EXPECT_EQ(*expr, *rebuilt); + + auto rebuilt_c = dynamic_cast(rebuilt); + EXPECT_TRUE(rebuilt_c->GetWhenClauseSize() == 3); + + // Check each when clause + for (int i = 0; i < 3; i++) { + auto res = rebuilt_c->GetWhenClauseResult(i); + auto res_c = dynamic_cast(res); + EXPECT_TRUE(res_c != nullptr && res_c->GetValue().GetTypeId() == type::TypeId::INTEGER); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(res_c->GetValue()) == (i + 1)); + + auto cond = rebuilt_c->GetWhenClauseCond(i); + EXPECT_TRUE(cond->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(cond->GetChildrenSize() == 2); + + auto rval = cond->GetChild(0); + EXPECT_TRUE(rval->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto rval_c = dynamic_cast(rval); + EXPECT_TRUE(rval_c != nullptr && rval_c->GetValue().GetTypeId() == type::TypeId::INTEGER); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(rval_c->GetValue()) == (i + 1)); + } + + // Check default clause + auto def_expr = dynamic_cast(rebuilt_c->GetDefault()); + EXPECT_TRUE(def_expr != nullptr); + EXPECT_TRUE(def_expr->GetValue().GetTypeId() == 
type::TypeId::INTEGER); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(def_expr->GetValue()) == 4); + delete rebuilt; +} + +TEST_F(AbsExprTest, SubqueryTest) { + std::vector> stmts; + { + auto parser = parser::PostgresParser::GetInstance(); + auto query = "SELECT * from foo"; + std::unique_ptr stmt_list(parser.BuildParseTree(query).release()); + stmts = std::move(stmt_list->statements); + } + + EXPECT_TRUE(stmts.size() == 1); + EXPECT_EQ(stmts[0]->GetType(), StatementType::SELECT); + parser::SelectStatement *sel = dynamic_cast(stmts[0].release()); + EXPECT_TRUE(sel != nullptr); + + auto expr = std::make_shared(); + expr->SetSubSelect(sel); + + AbsExpr_Container container = AbsExpr_Container(expr); + expression::AbstractExpression *rebuild = container.CopyWithChildren({}); + + EXPECT_EQ(rebuild->GetExpressionType(), expr->GetExpressionType()); + EXPECT_EQ(rebuild->GetChildrenSize(), expr->GetChildrenSize()); + + auto rebuild_s = dynamic_cast(rebuild); + EXPECT_TRUE(rebuild_s != nullptr); + EXPECT_TRUE(rebuild_s->GetSubSelect() == expr->GetSubSelect()); + delete rebuild; +} + +TEST_F(AbsExprTest, FunctionExpressionTest) { + std::vector child1; + std::vector child2; + std::vector types; + for (int i = 0; i < 10; i++) { + child1.push_back(GetConstantExpression(i)); + child2.push_back(GetConstantExpression(i)); + types.push_back(type::TypeId::INTEGER); + } + + auto func = [](const std::vector &val) { + (void)val; + return type::ValueFactory::GetIntegerValue(5); + }; + function::BuiltInFuncType func_ptr = { static_cast(1), func }; + auto expr = std::make_shared("func", child1); + expr->SetBuiltinFunctionExpressionParameters(func_ptr, type::TypeId::INTEGER, types); + + AbsExpr_Container container = AbsExpr_Container(expr); + expression::AbstractExpression* rebuild = container.CopyWithChildren(child2); + + EXPECT_EQ(rebuild->GetExpressionType(), expr->GetExpressionType()); + EXPECT_EQ(rebuild->GetChildrenSize(), expr->GetChildrenSize()); + + auto rebuild_c = 
dynamic_cast(rebuild); + EXPECT_TRUE(rebuild_c != nullptr); + EXPECT_EQ(rebuild_c->GetFuncName(), expr->GetFuncName()); + EXPECT_EQ(rebuild_c->GetFunc().op_id, expr->GetFunc().op_id); + EXPECT_EQ(rebuild_c->GetFunc().impl, expr->GetFunc().impl); + EXPECT_EQ(rebuild_c->GetArgTypes(), expr->GetArgTypes()); + EXPECT_EQ(rebuild_c->IsUDF(), expr->IsUDF()); + + for (int i = 0; i < 10; i++) { + auto l_child = rebuild_c->GetChild(i); + auto r_child = expr->GetChild(i); + EXPECT_TRUE(l_child != r_child && l_child->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto l_cast = dynamic_cast(l_child); + auto r_cast = dynamic_cast(r_child); + EXPECT_TRUE(l_cast != nullptr && r_cast != nullptr); + EXPECT_TRUE(l_cast->ExactlyEquals(*r_cast)); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(l_cast->GetValue()) == i); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(r_cast->GetValue()) == i); + } + + delete rebuild; +} + +} // namespace test +} // namespace peloton diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index 17f6002ebd0..c1247baeed6 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -347,14 +347,9 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { optimizer.TestInsertQueryTree(parse_tree, txn); std::vector child_groups = {gexpr->GetGroupID()}; -<<<<<<< HEAD std::shared_ptr head_gexpr = std::make_shared( std::make_shared(Operator()), child_groups); -======= - std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); ->>>>>>> templatize std::shared_ptr root_context = std::make_shared(&(optimizer.GetMetadata()), nullptr); @@ -440,14 +435,9 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { optimizer.TestInsertQueryTree(parse_tree, txn); std::vector child_groups = {gexpr->GetGroupID()}; -<<<<<<< HEAD std::shared_ptr head_gexpr = std::make_shared( std::make_shared(Operator()), child_groups); -======= - std::shared_ptr head_gexpr = - std::make_shared(Operator(), 
child_groups); ->>>>>>> templatize std::shared_ptr root_context = std::make_shared(&(optimizer.GetMetadata()), nullptr); diff --git a/test/optimizer/rewriter_test.cpp b/test/optimizer/rewriter_test.cpp index 48c4b0420b9..4ce84267d61 100644 --- a/test/optimizer/rewriter_test.cpp +++ b/test/optimizer/rewriter_test.cpp @@ -30,54 +30,6 @@ using namespace optimizer; class RewriterTests : public PelotonTest {}; -TEST_F(RewriterTests, ConvertAbsExpr) { - type::Value leftValue = type::ValueFactory::GetIntegerValue(1); - type::Value rightValue = type::ValueFactory::GetIntegerValue(2); - auto left = new expression::ConstantValueExpression(leftValue); - auto right = new expression::ConstantValueExpression(rightValue); - auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); - - Rewriter *rewriter = new Rewriter(); - - auto absexpr = rewriter->ConvertToAbsExpr(common); - EXPECT_TRUE(absexpr != nullptr); - EXPECT_TRUE(absexpr->Op().GetType() == ExpressionType::COMPARE_EQUAL); - EXPECT_TRUE(absexpr->Children().size() == 2); - - auto lefta = absexpr->Children()[0]; - auto righta = absexpr->Children()[1]; - EXPECT_TRUE(lefta != nullptr && righta != nullptr); - EXPECT_TRUE(lefta->Op().GetType() == righta->Op().GetType()); - EXPECT_TRUE(lefta->Op().GetType() == ExpressionType::VALUE_CONSTANT); - - auto left_cve = static_cast(lefta->Op().GetExpr()); - auto right_cve = static_cast(righta->Op().GetExpr()); - EXPECT_TRUE(left_cve == left); - EXPECT_TRUE(right_cve == right); - - // Try applying the rule - ComparatorElimination rule; - EXPECT_TRUE(rule.Check(absexpr, nullptr) == true); - - std::vector> transform; - rule.Transform(absexpr, transform, nullptr); - EXPECT_TRUE(transform.size() == 1); - - delete rewriter; - delete common; - - auto tr_expr = transform[0]; - EXPECT_TRUE(tr_expr != nullptr); - EXPECT_TRUE(tr_expr->Op().GetType() == ExpressionType::VALUE_CONSTANT); - EXPECT_TRUE(tr_expr->Children().size() == 0); - - auto tr_cve = 
static_cast(tr_expr->Op().GetExpr()); - EXPECT_TRUE(type::ValuePeeker::PeekBoolean(tr_cve->GetValue()) == false); - - // (TODO): hack to fix the memory leak bubbled from Transform() - delete tr_cve; -} - TEST_F(RewriterTests, SingleCompareEqualRewritePassFalse) { type::Value leftValue = type::ValueFactory::GetIntegerValue(3); type::Value rightValue = type::ValueFactory::GetIntegerValue(2); @@ -154,25 +106,24 @@ TEST_F(RewriterTests, SimpleEqualityTree) { delete rewrote; } -// (TODO): delete this test once more rewriting rules implemented -TEST_F(RewriterTests, SimpleJunctionPreserve) { - // [AND] - // [=] [=] - // [4] [5] [3] [3] +TEST_F(RewriterTests, ComparativeOperatorTest) { + // [=] + // [<=] [>=] + // [4] [4] [5] [3] type::Value val4 = type::ValueFactory::GetIntegerValue(4); type::Value val5 = type::ValueFactory::GetIntegerValue(5); type::Value val3 = type::ValueFactory::GetIntegerValue(3); auto lb_left_child = new expression::ConstantValueExpression(val4); - auto lb_right_child = new expression::ConstantValueExpression(val5); - auto rb_left_child = new expression::ConstantValueExpression(val3); + auto lb_right_child = new expression::ConstantValueExpression(val4); + auto rb_left_child = new expression::ConstantValueExpression(val5); auto rb_right_child = new expression::ConstantValueExpression(val3); - auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, lb_left_child, lb_right_child); - auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, rb_left_child, rb_right_child); - auto top = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lb, rb); + auto top = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, lb, rb); Rewriter *rewriter = new Rewriter(); auto rewrote = 
rewriter->RewriteExpression(top); @@ -181,23 +132,12 @@ TEST_F(RewriterTests, SimpleJunctionPreserve) { delete top; EXPECT_TRUE(rewrote != nullptr); - EXPECT_TRUE(rewrote->GetChildrenSize() == 2); - EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::CONJUNCTION_AND); - - auto left = rewrote->GetChild(0); - auto right = rewrote->GetChild(1); - - EXPECT_TRUE(left != nullptr && right != nullptr); - EXPECT_TRUE(left->GetExpressionType() == ExpressionType::VALUE_CONSTANT); - EXPECT_TRUE(right->GetExpressionType() == ExpressionType::VALUE_CONSTANT); - - auto left_cast = dynamic_cast(left); - auto right_cast = dynamic_cast(right); - EXPECT_TRUE(left_cast->GetValueType() == type::TypeId::BOOLEAN); - EXPECT_TRUE(right_cast->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); - EXPECT_TRUE(type::ValuePeeker::PeekBoolean(left_cast->GetValue()) == false); - EXPECT_TRUE(type::ValuePeeker::PeekBoolean(right_cast->GetValue()) == true); + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); delete rewrote; } diff --git a/test/optimizer/rule_rewrite_test.cpp b/test/optimizer/rule_rewrite_test.cpp new file mode 100644 index 00000000000..49cee9af835 --- /dev/null +++ b/test/optimizer/rule_rewrite_test.cpp @@ -0,0 +1,521 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rule_rewrite_test.cpp +// +// Identification: test/optimizer/rule_rewrite_test.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include "common/harness.h" + +#include "optimizer/operators.h" +#include "optimizer/rewriter.h" +#include "expression/constant_value_expression.h" +#include 
"expression/comparison_expression.h" +#include "expression/tuple_value_expression.h" +#include "type/value_factory.h" +#include "type/value_peeker.h" +#include "optimizer/rule_rewrite.h" + +namespace peloton { + +namespace test { + +using namespace optimizer; + +class RuleRewriteTests : public PelotonTest { + public: + // Creates expresson: (A = X) AND (B = Y) + expression::AbstractExpression *CreateMultiLevelExpression(expression::AbstractExpression *a, + expression::AbstractExpression *x, + expression::AbstractExpression *b, + expression::AbstractExpression *y) { + auto left_eq = new expression::ComparisonExpression( + ExpressionType::COMPARE_EQUAL, a->Copy(), x->Copy() + ); + + auto right_eq = new expression::ComparisonExpression( + ExpressionType::COMPARE_EQUAL, b->Copy(), y->Copy() + ); + + return new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, left_eq, right_eq); + } + + expression::ConstantValueExpression *GetConstantExpression(int val) { + auto value = type::ValueFactory::GetIntegerValue(val); + return new expression::ConstantValueExpression(value); + } +}; + +TEST_F(RuleRewriteTests, ComparatorEliminationEqual) { + // (1 == 1) => (TRUE) + auto left = GetConstantExpression(1); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + // (1 == 2) => (FALSE) + auto left_f = GetConstantExpression(1); + auto right_f = GetConstantExpression(2); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + auto casted_f = dynamic_cast(rewrote_f); + 
EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == false); + + delete rewrote; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationNotEqual) { + // (1 != 1) => (FALSE) + auto left = GetConstantExpression(1); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_NOTEQUAL, left, right); + + // (1 != 2) => (TRUE) + auto left_f = GetConstantExpression(1); + auto right_f = GetConstantExpression(2); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_NOTEQUAL, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == true); + + delete rewrote; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationLessThan) { + // (0 < 1) => (TRUE) + auto left = GetConstantExpression(0); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHAN, left, right); + + // (1 < 1) => (FALSE) + auto left_ef = GetConstantExpression(1); + auto right_ef = GetConstantExpression(1); + auto equal_ef = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHAN, left_ef, right_ef); + + // (2 < 1) => (FALSE) + auto left_f = GetConstantExpression(2); + auto right_f = GetConstantExpression(1); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHAN, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto 
rewrote_ef = rewriter->RewriteExpression(equal_ef); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_ef; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + auto casted_ef = dynamic_cast(rewrote_ef); + EXPECT_TRUE(casted_ef != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_ef->GetValue()) == false); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == false); + delete rewrote; + delete rewrote_ef; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationGreaterThan) { + // (0 > 1) => (FALSE) + auto left = GetConstantExpression(0); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHAN, left, right); + + // (1 < 1) => (FALSE) + auto left_ef = GetConstantExpression(1); + auto right_ef = GetConstantExpression(1); + auto equal_ef = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHAN, left_ef, right_ef); + + // (2 > 1) => (TRUE) + auto left_f = GetConstantExpression(2); + auto right_f = GetConstantExpression(1); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHAN, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_ef = rewriter->RewriteExpression(equal_ef); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_ef; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + auto casted_ef = dynamic_cast(rewrote_ef); + EXPECT_TRUE(casted_ef != nullptr); + 
EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_ef->GetValue()) == false); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == true); + + delete rewrote; + delete rewrote_ef; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationLessThanOrEqualTo) { + // (0 <= 1) => (TRUE) + auto left = GetConstantExpression(0); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, left, right); + + // (1 <= 1) => (TRUE) + auto left_ef = GetConstantExpression(1); + auto right_ef = GetConstantExpression(1); + auto equal_ef = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, left_ef, right_ef); + + // (2 <= 1) => (FALSE) + auto left_f = GetConstantExpression(2); + auto right_f = GetConstantExpression(1); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_ef = rewriter->RewriteExpression(equal_ef); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_ef; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + auto casted_ef = dynamic_cast(rewrote_ef); + EXPECT_TRUE(casted_ef != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_ef->GetValue()) == true); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == false); + + delete rewrote; + delete rewrote_ef; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationGreaterThanOrEqualTo) { + // (0 >= 1) => (FALSE) + auto left = GetConstantExpression(0); + 
auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, left, right); + + // (1 >= 1) => (TRUE) + auto left_ef = GetConstantExpression(1); + auto right_ef = GetConstantExpression(1); + auto equal_ef = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, left_ef, right_ef); + + // (2 >= 1) => (TRUE) + auto left_f = GetConstantExpression(2); + auto right_f = GetConstantExpression(1); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_ef = rewriter->RewriteExpression(equal_ef); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_ef; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + auto casted_ef = dynamic_cast(rewrote_ef); + EXPECT_TRUE(casted_ef != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_ef->GetValue()) == true); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == true); + + delete rewrote; + delete rewrote_ef; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationLessThanOrEqualToNull) { + auto valNULL = type::ValueFactory::GetNullValueByType(type::TypeId::INTEGER); + + // 0 <= NULL => NULL + auto left = GetConstantExpression(2); + auto right = new expression::ConstantValueExpression(valNULL); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, left, right); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + delete rewriter; + delete equal; + + auto casted = dynamic_cast(rewrote); + 
EXPECT_TRUE(casted != nullptr); + + auto value = casted->GetValue(); + EXPECT_TRUE(value.GetTypeId() == type::TypeId::BOOLEAN); + EXPECT_TRUE(value.IsNull()); + + delete rewrote; +} + +TEST_F(RuleRewriteTests, TVEqualTwoCVFalseTransform) { + auto cv1 = GetConstantExpression(1); + auto cv2 = GetConstantExpression(2); + auto tv_base = new expression::TupleValueExpression("B", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base: (A.B = 1) AND (A.B = 2) + auto base = CreateMultiLevelExpression(tv_base, cv1, tv_base, cv2); + + // Inverse: (1 = A.B) AND (2 = A.B) + auto inverse = CreateMultiLevelExpression(cv1, tv_base, cv2, tv_base); + + // Inner Flip Left: (1 = A.B) AND (A.B = 2) + auto if_left = CreateMultiLevelExpression(cv1, tv_base, tv_base, cv2); + + // Inner Flip Right: (A.B = 1) AND (2 = A.B) + auto if_right = CreateMultiLevelExpression(tv_base, cv1, cv2, tv_base); + + std::vector rewrites; + rewrites.push_back(rewriter->RewriteExpression(base)); + rewrites.push_back(rewriter->RewriteExpression(inverse)); + rewrites.push_back(rewriter->RewriteExpression(if_left)); + rewrites.push_back(rewriter->RewriteExpression(if_right)); + delete rewriter; + delete cv1; + delete cv2; + delete tv_base; + delete base; + delete inverse; + delete if_left; + delete if_right; + + for (auto expr : rewrites) { + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + EXPECT_TRUE(expr->GetChildrenSize() == 0); + + auto expr_val = dynamic_cast(expr); + EXPECT_TRUE(expr_val != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(expr_val->GetValue()) == false); + } + + while (!rewrites.empty()) { + auto expr = rewrites.back(); + rewrites.pop_back(); + delete expr; + } +} + +TEST_F(RuleRewriteTests, TVEqualTwoCVTrueTransform) { + auto cv1 = GetConstantExpression(1); + auto tv_base = new expression::TupleValueExpression("B", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base: (A.B = 1) AND (A.B = 1) + auto base = CreateMultiLevelExpression(tv_base, 
cv1, tv_base, cv1); + + // Inverse: (1 = A.B) AND (1 = A.B) + auto inverse = CreateMultiLevelExpression(cv1, tv_base, cv1, tv_base); + + // Inner Flip Left: (1 = A.B) AND (A.B = 1) + auto if_left = CreateMultiLevelExpression(cv1, tv_base, tv_base, cv1); + + // Inner Flip Right: (A.B = 1) AND (1 = A.B) + auto if_right = CreateMultiLevelExpression(tv_base, cv1, cv1, tv_base); + + std::vector rewrites; + rewrites.push_back(rewriter->RewriteExpression(base)); + rewrites.push_back(rewriter->RewriteExpression(inverse)); + rewrites.push_back(rewriter->RewriteExpression(if_left)); + rewrites.push_back(rewriter->RewriteExpression(if_right)); + delete rewriter; + delete cv1; + delete base; + delete inverse; + delete if_left; + delete if_right; + + for (auto expr : rewrites) { + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(expr->GetChildrenSize() == 2); + + auto left_tv = expr->GetModifiableChild(0); + auto tv = dynamic_cast(left_tv); + EXPECT_TRUE(tv != nullptr); + EXPECT_TRUE(tv->ExactlyEquals(*tv_base)); + + auto right_cv = expr->GetModifiableChild(1); + auto cv = dynamic_cast(right_cv); + EXPECT_TRUE(cv != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(cv->GetValue()) == 1); + } + + while (!rewrites.empty()) { + auto expr = rewrites.back(); + rewrites.pop_back(); + delete expr; + } + + delete tv_base; +} + +TEST_F(RuleRewriteTests, TransitiveClosureUnableTest) { + auto cv1 = GetConstantExpression(1); + auto tv_base1 = new expression::TupleValueExpression("B", "A"); + auto tv_base2 = new expression::TupleValueExpression("C", "A"); + auto tv_base3 = new expression::TupleValueExpression("D", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base (A = 1) AND (B = C) + auto base = CreateMultiLevelExpression(tv_base1, cv1, tv_base2, tv_base3); + + auto expr = rewriter->RewriteExpression(base); + delete rewriter; + delete base; + + // Returned expression should not be changed + EXPECT_TRUE(expr->GetExpressionType() == 
ExpressionType::CONJUNCTION_AND); + EXPECT_TRUE(expr->GetChildrenSize() == 2); + + auto left_eq = expr->GetModifiableChild(0); + auto right_eq = expr->GetModifiableChild(1); + EXPECT_TRUE(left_eq->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(right_eq->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(left_eq->GetChildrenSize() == 2); + EXPECT_TRUE(right_eq->GetChildrenSize() == 2); + + auto ll_tv = dynamic_cast(left_eq->GetModifiableChild(0)); + auto lr_cv = dynamic_cast(left_eq->GetModifiableChild(1)); + auto rl_tv = dynamic_cast(right_eq->GetModifiableChild(0)); + auto rr_tv = dynamic_cast(right_eq->GetModifiableChild(1)); + EXPECT_TRUE(ll_tv != nullptr && lr_cv != nullptr && rl_tv != nullptr && rr_tv != nullptr); + EXPECT_TRUE(lr_cv->ExactlyEquals(*cv1)); + EXPECT_TRUE(ll_tv->ExactlyEquals(*tv_base1)); + EXPECT_TRUE(rl_tv->ExactlyEquals(*tv_base2)); + EXPECT_TRUE(rr_tv->ExactlyEquals(*tv_base3)); + + delete cv1; + delete tv_base1; + delete tv_base2; + delete tv_base3; + delete expr; +} + +TEST_F(RuleRewriteTests, TransitiveClosureRewrite) { + auto cv1 = GetConstantExpression(1); + auto tv_base1 = new expression::TupleValueExpression("B", "A"); + auto tv_base2 = new expression::TupleValueExpression("C", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base (A = 1) AND (A = B) + auto base = CreateMultiLevelExpression(tv_base1, cv1, tv_base1, tv_base2); + + auto expr = rewriter->RewriteExpression(base); + delete rewriter; + delete base; + + // Returned expression should not be changed + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND); + EXPECT_TRUE(expr->GetChildrenSize() == 2); + + auto left_eq = expr->GetModifiableChild(0); + auto right_eq = expr->GetModifiableChild(1); + EXPECT_TRUE(left_eq->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(right_eq->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(left_eq->GetChildrenSize() == 2); + 
EXPECT_TRUE(right_eq->GetChildrenSize() == 2); + + auto ll_tv = dynamic_cast(left_eq->GetModifiableChild(0)); + auto lr_cv = dynamic_cast(left_eq->GetModifiableChild(1)); + auto rl_cv = dynamic_cast(right_eq->GetModifiableChild(0)); + auto rr_tv = dynamic_cast(right_eq->GetModifiableChild(1)); + EXPECT_TRUE(ll_tv != nullptr && lr_cv != nullptr && rl_cv != nullptr && rr_tv != nullptr); + EXPECT_TRUE(lr_cv->ExactlyEquals(*cv1)); + EXPECT_TRUE(ll_tv->ExactlyEquals(*tv_base1)); + EXPECT_TRUE(rl_cv->ExactlyEquals(*cv1)); + EXPECT_TRUE(rr_tv->ExactlyEquals(*tv_base2)); + + delete cv1; + delete tv_base1; + delete tv_base2; + delete expr; +} + +TEST_F(RuleRewriteTests, TransitiveClosureHalfTrue) { + auto cv1 = GetConstantExpression(1); + auto tv_base1 = new expression::TupleValueExpression("B", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base (A = 1) AND (A = B) + auto base = CreateMultiLevelExpression(tv_base1, cv1, tv_base1, tv_base1); + + auto expr = rewriter->RewriteExpression(base); + delete rewriter; + delete base; + + // Returned expression should not be changed + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(expr->GetChildrenSize() == 2); + + auto ll_tv = dynamic_cast(expr->GetModifiableChild(0)); + auto lr_cv = dynamic_cast(expr->GetModifiableChild(1)); + EXPECT_TRUE(ll_tv != nullptr && lr_cv != nullptr); + EXPECT_TRUE(lr_cv->ExactlyEquals(*cv1)); + EXPECT_TRUE(ll_tv->ExactlyEquals(*tv_base1)); + + delete cv1; + delete tv_base1; + delete expr; +} + +} // namespace test +} // namespace peloton From e3ac1ba195c9550ae7d921e79b3b505a97c92e86 Mon Sep 17 00:00:00 2001 From: William Zhang <17zhangw@gmail.com> Date: Sun, 5 May 2019 07:24:30 +0000 Subject: [PATCH 08/14] Ported not-null foreign keys and short-circuiting --- src/binder/bind_node_visitor.cpp | 17 ++ src/include/common/internal_types.h | 8 + .../expression/tuple_value_expression.h | 12 +- src/include/optimizer/rule_rewrite.h | 47 +++ src/optimizer/binding.cpp 
| 4 + src/optimizer/rule.cpp | 9 + src/optimizer/rule_rewrite.cpp | 223 ++++++++++++++ test/optimizer/rewriter_test.cpp | 279 ++++++++++++++++++ 8 files changed, 598 insertions(+), 1 deletion(-) diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 2ccd85dcdb2..183c73e817d 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -12,6 +12,8 @@ #include "binder/bind_node_visitor.h" #include "catalog/catalog.h" +#include "catalog/table_catalog.h" +#include "catalog/column_catalog.h" #include "expression/expression_util.h" #include "expression/star_expression.h" #include "type/type_id.h" @@ -250,6 +252,21 @@ void BindNodeVisitor::Visit(expression::TupleValueExpression *expr) { expr->SetColName(col_name); expr->SetValueType(value_type); expr->SetBoundOid(col_pos_tuple); + + // TODO(esargent): Uncommenting the following code makes AddressSanitizer get mad at me with a + // heap buffer overflow whenever I try a query that references the same non-null attribute multiple + // times (e.g. 'SELECT id FROM t WHERE id < 3 AND id > 1'). Leaving it commented out prevents the + // memory error, but then this prevents the is_not_null flag of a tuple expression from being + // populated in some cases (specifically, when the expression's table name is initially empty). 
+ + //if (table_obj == nullptr) { + // LOG_DEBUG("Extracting regular table object"); + // BinderContext::GetRegularTableObj(context_, table_name, table_obj, depth); + //} + + if (table_obj != nullptr) { + expr->SetIsNotNull(table_obj->GetColumnCatalogEntry(std::get<2>(col_pos_tuple), false)->IsNotNull()); + } } } diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 39c9647b2ef..e81ec101b02 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1393,6 +1393,14 @@ enum class RuleType : uint32_t { TV_EQUALITY_WITH_TWO_CV, // (A.B = x) AND (A.B = y) where x/y are constant TRANSITIVE_CLOSURE_CONSTANT, // (A.B = x) AND (A.B = C.D) + // Boolean short-circuit rules + AND_SHORT_CIRCUIT, // (FALSE AND B) + OR_SHORT_CIRCUIT, // (TRUE OR B) + + // Catalog-based NULL/NON-NULL rules + NULL_LOOKUP_ON_NOT_NULL_COLUMN, + NOT_NULL_LOOKUP_ON_NOT_NULL_COLUMN, + // Place holder to generate number of rules compile time NUM_RULES diff --git a/src/include/expression/tuple_value_expression.h b/src/include/expression/tuple_value_expression.h index dab5d1e4ddd..16f37ae8645 100644 --- a/src/include/expression/tuple_value_expression.h +++ b/src/include/expression/tuple_value_expression.h @@ -79,6 +79,10 @@ class TupleValueExpression : public AbstractExpression { tuple_idx_ = tuple_idx; } + inline void SetIsNotNull(bool is_not_null) { + is_not_null_ = is_not_null; + } + /** * @brief Attribute binding * @param binding_contexts @@ -116,6 +120,8 @@ class TupleValueExpression : public AbstractExpression { if ((table_name_.empty() xor other.table_name_.empty()) || col_name_.empty() xor other.col_name_.empty()) return false; + if (GetIsNotNull() != other.GetIsNotNull()) + return false; bool res = bound_obj_id_ == other.bound_obj_id_; if (!table_name_.empty() && !other.table_name_.empty()) res = table_name_ == other.table_name_ && res; @@ -151,6 +157,8 @@ class TupleValueExpression : public AbstractExpression { bool 
GetIsBound() const { return is_bound_; } + bool GetIsNotNull() const { return is_not_null_; } + const std::tuple &GetBoundOid() const { return bound_obj_id_; } @@ -185,7 +193,8 @@ class TupleValueExpression : public AbstractExpression { value_idx_(other.value_idx_), tuple_idx_(other.tuple_idx_), table_name_(other.table_name_), - col_name_(other.col_name_) {} + col_name_(other.col_name_), + is_not_null_(other.is_not_null_) {} // Bound flag bool is_bound_ = false; @@ -196,6 +205,7 @@ class TupleValueExpression : public AbstractExpression { int tuple_idx_; std::string table_name_; std::string col_name_; + bool is_not_null_ = false; const planner::AttributeInfo *ai_; }; diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h index 8df83556626..ab739aa177d 100644 --- a/src/include/optimizer/rule_rewrite.h +++ b/src/include/optimizer/rule_rewrite.h @@ -20,6 +20,9 @@ namespace peloton { namespace optimizer { +using GroupExprTemplate = GroupExpression; +using OptimizeContext = OptimizeContext; + /* Rules are applied from high to low priority */ enum class RulePriority : int { HIGH = 3, @@ -71,5 +74,49 @@ class TransitiveClosureConstantTransform: public Rule { OptimizeContext *context) const override; }; +class AndShortCircuit: public Rule { + public: + AndShortCircuit(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +class OrShortCircuit: public Rule { + public: + OrShortCircuit(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +class NullLookupOnNotNullColumn: public Rule 
{ + public: + NullLookupOnNotNullColumn(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +class NotNullLookupOnNotNullColumn: public Rule { + public: + NotNullLookupOnNotNullColumn(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index 807c4c42f94..986710ab0ab 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -12,11 +12,15 @@ #include "optimizer/binding.h" +#include + #include "common/logger.h" #include "optimizer/operator_visitor.h" #include "optimizer/optimizer.h" #include "optimizer/absexpr_expression.h" #include "expression/group_marker_expression.h" +#include "expression/abstract_expression.h" +#include "expression/tuple_value_expression.h" namespace peloton { namespace optimizer { diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 3baec8dba86..fc6e814e58c 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -76,6 +76,15 @@ RuleSet::RuleSet() { AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TVEqualityWithTwoCVTransform()); AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TransitiveClosureConstantTransform()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new AndShortCircuit()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new OrShortCircuit()); + + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new NullLookupOnNotNullColumn()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new 
NotNullLookupOnNotNullColumn()); + + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TVEqualityWithTwoCVTransform()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TransitiveClosureConstantTransform()); + // Define transformation/implementation rules AddTransformationRule(new InnerJoinCommutativity()); AddTransformationRule(new InnerJoinAssociativity()); diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp index b804c08e488..7bcb3aaadf1 100644 --- a/src/optimizer/rule_rewrite.cpp +++ b/src/optimizer/rule_rewrite.cpp @@ -399,5 +399,228 @@ void TransitiveClosureConstantTransform::Transform(std::shared_ptr) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_AND); + auto left_child = std::make_shared(ExpressionType::VALUE_CONSTANT); + auto right_child = std::make_shared(ExpressionType::GROUP_MARKER); + + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); +} + +int AndShortCircuit::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::HIGH); +} + +bool AndShortCircuit::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void AndShortCircuit::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (FALSE AND ) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_AND); + + std::shared_ptr left = input->Children()[0]; + PELOTON_ASSERT(left->Children().size() == 0); + PELOTON_ASSERT(left->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); + PELOTON_ASSERT(left_c != nullptr); + + auto left_cv_expr = 
std::dynamic_pointer_cast(left_c->GetExpr()); + type::Value left_value = left_cv_expr->GetValue(); + + LOG_DEBUG("fjdsklafjksdjflkadsjf"); + + // Only transform the expression if we're ANDing a FALSE boolean value + if (left_value.GetTypeId() == type::TypeId::BOOLEAN && left_value.IsFalse()) { + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + std::shared_ptr false_expr = std::make_shared(val_false); + std::shared_ptr false_cnt = std::make_shared(false_expr); + std::shared_ptr false_container = std::make_shared(false_cnt); + transformed.push_back(false_container); + } +} + + +OrShortCircuit::OrShortCircuit() { + type_ = RuleType::OR_SHORT_CIRCUIT; + + // (FALSE AND ) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_OR); + auto left_child = std::make_shared(ExpressionType::VALUE_CONSTANT); + auto right_child = std::make_shared(ExpressionType::GROUP_MARKER); + + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); +} + +int OrShortCircuit::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::HIGH); +} + +bool OrShortCircuit::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void OrShortCircuit::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (TRUE OR ) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_OR); + + std::shared_ptr left = input->Children()[0]; + PELOTON_ASSERT(left->Children().size() == 0); + PELOTON_ASSERT(left->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); + PELOTON_ASSERT(left_c != nullptr); + + auto left_cv_expr = 
std::dynamic_pointer_cast(left_c->GetExpr()); + type::Value left_value = left_cv_expr->GetValue(); + + // Only transform the expression if we're ANDing a TRUE boolean value + if (left_value.GetTypeId() == type::TypeId::BOOLEAN && left_value.IsTrue()) { + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + std::shared_ptr true_expr = std::make_shared(val_true); + std::shared_ptr true_cnt = std::make_shared(true_expr); + std::shared_ptr true_container = std::make_shared(true_cnt); + transformed.push_back(true_container); + } +} + + +NullLookupOnNotNullColumn::NullLookupOnNotNullColumn() { + type_ = RuleType::NULL_LOOKUP_ON_NOT_NULL_COLUMN; + + // Structure: [T.X IS NULL] + match_pattern = std::make_shared(ExpressionType::OPERATOR_IS_NULL); + auto child = std::make_shared(ExpressionType::VALUE_TUPLE); + + match_pattern->AddChild(child); +} + +int NullLookupOnNotNullColumn::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::LOW); +} + +bool NullLookupOnNotNullColumn::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void NullLookupOnNotNullColumn::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (TRUE OR ) + PELOTON_ASSERT(input->Children().size() == 1); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::OPERATOR_IS_NULL); + + std::shared_ptr child = input->Children()[0]; + PELOTON_ASSERT(child->Children().size() == 0); + PELOTON_ASSERT(child->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + + std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); + PELOTON_ASSERT(child_c != nullptr); + + auto tuple_expr = std::dynamic_pointer_cast(child_c->GetExpr()); + + // Only transform into [FALSE] if the tuple 
value expression is specifically non-NULL, + // otherwise do nothing + if (tuple_expr->GetIsNotNull()) { + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + std::shared_ptr false_expr = std::make_shared(val_false); + std::shared_ptr false_cnt = std::make_shared(false_expr); + std::shared_ptr false_container = std::make_shared(false_cnt); + transformed.push_back(false_container); + } +} + +NotNullLookupOnNotNullColumn::NotNullLookupOnNotNullColumn() { + type_ = RuleType::NOT_NULL_LOOKUP_ON_NOT_NULL_COLUMN; + + // Structure: [T.X IS NOT NULL] + match_pattern = std::make_shared(ExpressionType::OPERATOR_IS_NOT_NULL); + auto child = std::make_shared(ExpressionType::VALUE_TUPLE); + + match_pattern->AddChild(child); +} + +int NotNullLookupOnNotNullColumn::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(RulePriority::LOW); +} + +bool NotNullLookupOnNotNullColumn::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void NotNullLookupOnNotNullColumn::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (TRUE OR ) + PELOTON_ASSERT(input->Children().size() == 1); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::OPERATOR_IS_NOT_NULL); + + std::shared_ptr child = input->Children()[0]; + PELOTON_ASSERT(child->Children().size() == 0); + PELOTON_ASSERT(child->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + + std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); + auto tuple_expr = std::dynamic_pointer_cast(child_c->GetExpr()); + + // Only transform into [TRUE] if the tuple value expression is specifically non-NULL, + // otherwise do nothing + if (tuple_expr->GetIsNotNull()) { + type::Value val_true = 
type::ValueFactory::GetBooleanValue(true); + std::shared_ptr true_expr = std::make_shared(val_true); + std::shared_ptr true_cnt = std::make_shared(true_expr); + std::shared_ptr true_container = std::make_shared(true_cnt); + transformed.push_back(true_container); + } +} + } // namespace optimizer } // namespace peloton diff --git a/test/optimizer/rewriter_test.cpp b/test/optimizer/rewriter_test.cpp index 4ce84267d61..c9faacca9c0 100644 --- a/test/optimizer/rewriter_test.cpp +++ b/test/optimizer/rewriter_test.cpp @@ -18,6 +18,7 @@ #include "expression/constant_value_expression.h" #include "expression/comparison_expression.h" #include "expression/tuple_value_expression.h" +#include "expression/operator_expression.h" #include "type/value_factory.h" #include "type/value_peeker.h" #include "optimizer/rule_rewrite.h" @@ -142,5 +143,283 @@ TEST_F(RewriterTests, ComparativeOperatorTest) { delete rewrote; } +TEST_F(RewriterTests, BasicAndShortCircuitTest) { + + // First, build the rewriter and the values that will be used in test cases + Rewriter *rewriter = new Rewriter(); + + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + // + // [AND] + // [FALSE] [=] + // [X] [3] + // + // Intended output: [FALSE] + // + + expression::ConstantValueExpression *lh = new expression::ConstantValueExpression(val_false); + expression::ConstantValueExpression *rh_right_child = new expression::ConstantValueExpression(val3); + expression::TupleValueExpression *rh_left_child = new expression::TupleValueExpression("t","x"); + + expression::ComparisonExpression *rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + expression::ConjunctionExpression *root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lh, rh); + + expression::AbstractExpression *rewrote = 
rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + delete rewrote; + delete root; + + // + // [AND] + // [TRUE] [=] + // [X] [3] + // + // Intended output: same as input + // + + lh = new expression::ConstantValueExpression(val_true); + rh_right_child = new expression::ConstantValueExpression(val3); + rh_left_child = new expression::TupleValueExpression("t","x"); + + rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lh, rh); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 2); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::CONJUNCTION_AND); + + delete rewrote; + delete root; + + delete rewriter; +} + + +TEST_F(RewriterTests, BasicOrShortCircuitTest) { + // First, build the rewriter and the values that will be used in test cases + Rewriter *rewriter = new Rewriter(); + + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + // + // [OR] + // [TRUE] [=] + // [X] [3] + // + // Intended output: [TRUE] + // + + expression::ConstantValueExpression *lh = new expression::ConstantValueExpression(val_true); + expression::ConstantValueExpression *rh_right_child = new expression::ConstantValueExpression(val3); + expression::TupleValueExpression *rh_left_child = new expression::TupleValueExpression("t","x"); + + expression::ComparisonExpression *rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + expression::ConjunctionExpression *root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lh, rh); + + 
expression::AbstractExpression *rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + delete rewrote; + delete root; + + // + // [OR] + // [FALSE] [=] + // [X] [3] + // + // Intended output: same as input + // + + lh = new expression::ConstantValueExpression(val_false); + rh_right_child = new expression::ConstantValueExpression(val3); + rh_left_child = new expression::TupleValueExpression("t","x"); + + rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lh, rh); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 2); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::CONJUNCTION_OR); + + delete rewrote; + delete root; + + delete rewriter; +} + + +TEST_F(RewriterTests, AndShortCircuitComparatorEliminationMixTest) { + // [AND] + // [<=] [=] + // [4] [4] [5] [3] + // Intended Output: FALSE + // + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val4); + auto rb_left_child = new expression::ConstantValueExpression(val5); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, lb, rb); + + Rewriter *rewriter = new 
Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + + +TEST_F(RewriterTests, OrShortCircuitComparatorEliminationMixTest) { + // [OR] + // [<=] [=] + // [4] [4] [5] [3] + // Intended Output: TRUE + // + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val4); + auto rb_left_child = new expression::ConstantValueExpression(val5); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + delete rewrote; +} + + +TEST_F(RewriterTests, NotNullColumnsTest) { + + // First, build rewriter to be used 
in all test cases + Rewriter *rewriter = new Rewriter(); + + // [T.X IS NULL], where X is a non-NULL column in table T + // Intended output: FALSE + + auto child = new expression::TupleValueExpression("t","x"); + child->SetIsNotNull(true); + auto root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NULL, type::TypeId::BOOLEAN, child, nullptr); + + auto rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_EQ(casted->GetValueType(), type::TypeId::BOOLEAN); + EXPECT_EQ(type::ValuePeeker::PeekBoolean(casted->GetValue()), false); + + delete root; + delete rewrote; + + // [T.X IS NOT NULL], where X is a non-NULL column in table T + // Intended output: TRUE + + child = new expression::TupleValueExpression("t","x"); + child->SetIsNotNull(true); + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NOT_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + casted = dynamic_cast(rewrote); + EXPECT_EQ(casted->GetValueType(), type::TypeId::BOOLEAN); + EXPECT_EQ(type::ValuePeeker::PeekBoolean(casted->GetValue()), true); + + delete root; + delete rewrote; + + // [T.Y IS NULL], where Y is a possibly NULL column in table T + // Intended output: same as input + + child = new expression::TupleValueExpression("t","y"); + child->SetIsNotNull(false); // is_not_null is false by default, but explicitly setting it is for readability's sake + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_EQ(rewrote->GetChildrenSize(), 1); + 
EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::OPERATOR_IS_NULL); + + delete root; + delete rewrote; + + // [T.Y IS NOT NULL], where Y is a possibly NULL column in table T + // Intended output: same as input + + child = new expression::TupleValueExpression("t","y"); + child->SetIsNotNull(false); // is_not_null is false by default, but explicitly setting it is for readability's sake + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NOT_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_EQ(rewrote->GetChildrenSize(), 1); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::OPERATOR_IS_NOT_NULL); + + delete root; + delete rewrote; + + delete rewriter; +} + + } // namespace test } // namespace peloton From 4984dbf6d5ce295532b399dfdf1ea8fa9452d0e9 Mon Sep 17 00:00:00 2001 From: Erik Sargent Date: Tue, 14 May 2019 18:49:55 -0400 Subject: [PATCH 09/14] Added documentation on rule_rewrite.h --- src/include/optimizer/rule_rewrite.h | 48 ++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h index ab739aa177d..d82f97da898 100644 --- a/src/include/optimizer/rule_rewrite.h +++ b/src/include/optimizer/rule_rewrite.h @@ -30,6 +30,15 @@ enum class RulePriority : int { LOW = 1 }; +/* + * Comparator Elimination: When two constant values are compared against + * each other (==, !=, >, <, >=, <=), the comparison expression gets rewritten + * to either TRUE or FALSE, depending on whether the constants agree with the + * comparison or not + * Examples: + * "1 == 2" ==> "FALSE" + * "3 <= 4" ==> "TRUE" + */ class ComparatorElimination: public Rule { public: ComparatorElimination(RuleType rule, ExpressionType root); @@ -41,6 +50,13 @@ class ComparatorElimination: public Rule { OptimizeContext *context) const override; }; +/* + * Equivalent Transform: When a symmetric operator (==, !=, AND, OR) has two + * 
children, the comparison expression gets its arguments flipped. + * Examples: + * "T.X != 3" ==> "3 != T.X" + * "(T.X == 1) AND (T.Y == 2)" ==> "(T.Y == 2) AND (T.X == 1)" + */ class EquivalentTransform: public Rule { public: EquivalentTransform(RuleType rule, ExpressionType root); @@ -52,6 +68,13 @@ class EquivalentTransform: public Rule { OptimizeContext *context) const override; }; +/* + * Tuple Value Equality with Two Constant Values: When the same tuple reference + * is checked against two distinct constant values, the expression is rewritten + * to FALSE + * Example: + * "(T.X == 3) AND (T.X == 4)" ==> "FALSE" + */ class TVEqualityWithTwoCVTransform: public Rule { public: TVEqualityWithTwoCVTransform(); @@ -63,6 +86,13 @@ class TVEqualityWithTwoCVTransform: public Rule { OptimizeContext *context) const override; }; +/* + * Transitive Closure w/ Constants: When two tuple references are compared against each + * other and one of the tuple references is compared to a constant, the expression + * swaps out the doubled tuple reference for the constant. + * Example: + * "(T.X == Q.Y) AND (T.X == 6)" ==> "(6 == Q.Y) AND (T.X == 6)" + */ class TransitiveClosureConstantTransform: public Rule { public: TransitiveClosureConstantTransform(); @@ -74,6 +104,9 @@ class TransitiveClosureConstantTransform: public Rule { OptimizeContext *context) const override; }; +/* + * And Short Circuiting: Anything AND FALSE is rewritten to FALSE. + */ class AndShortCircuit: public Rule { public: AndShortCircuit(); @@ -85,6 +118,9 @@ class AndShortCircuit: public Rule { OptimizeContext *context) const override; }; +/* + * Or Short Circuiting: Anything OR TRUE is rewritten to TRUE. 
+ */ class OrShortCircuit: public Rule { public: OrShortCircuit(); @@ -96,6 +132,12 @@ class OrShortCircuit: public Rule { OptimizeContext *context) const override; }; +/* + * Null Lookup on Not Null Column: Asking if a tuple reference is NULL is rewritten + * to FALSE only when the catalog says that that attribute has a non-NULL constraint. + * Example: + * "T.X IS NULL" ==> "FALSE" (assuming T.X is a non-NULL attribute) + */ class NullLookupOnNotNullColumn: public Rule { public: NullLookupOnNotNullColumn(); @@ -107,6 +149,12 @@ class NullLookupOnNotNullColumn: public Rule { OptimizeContext *context) const override; }; +/* + * Not Null Lookup on Not Null Column: Asking if a tuple reference is NOT NULL is rewritten + * to TRUE only when the catalog says that that attribute has a non-NULL constraint. + * Example: + * "T.X IS NOT NULL" ==> "TRUE" (assuming T.X is a non-NULL attribute) + */ class NotNullLookupOnNotNullColumn: public Rule { public: NotNullLookupOnNotNullColumn(); From 787cdfdc73a571fa9a3e3b446f3ce6a6cf3c0ccc Mon Sep 17 00:00:00 2001 From: Erik Sargent Date: Tue, 14 May 2019 18:53:15 -0400 Subject: [PATCH 10/14] Switched order of include expressions in abstract_node_expression.h --- src/include/optimizer/abstract_node_expression.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/optimizer/abstract_node_expression.h b/src/include/optimizer/abstract_node_expression.h index 01dbc40683c..85736616032 100644 --- a/src/include/optimizer/abstract_node_expression.h +++ b/src/include/optimizer/abstract_node_expression.h @@ -12,11 +12,11 @@ #pragma once -#include "optimizer/abstract_node.h" - #include #include +#include "optimizer/abstract_node.h" + namespace peloton { namespace optimizer { From 23a7fbf248c46d55f261cf58cc4ef4f670bb9290 Mon Sep 17 00:00:00 2001 From: Erik Sargent Date: Tue, 14 May 2019 19:02:22 -0400 Subject: [PATCH 11/14] Swapped order of #includes in abstract_node_expression.h --- 
src/include/optimizer/abstract_node_expression.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/optimizer/abstract_node_expression.h b/src/include/optimizer/abstract_node_expression.h index 01dbc40683c..85736616032 100644 --- a/src/include/optimizer/abstract_node_expression.h +++ b/src/include/optimizer/abstract_node_expression.h @@ -12,11 +12,11 @@ #pragma once -#include "optimizer/abstract_node.h" - #include #include +#include "optimizer/abstract_node.h" + namespace peloton { namespace optimizer { From da839e840f40b0a25dd328386d9ad1a47b521b45 Mon Sep 17 00:00:00 2001 From: William Zhang <17zhangw@gmail.com> Date: Tue, 14 May 2019 19:19:31 -0400 Subject: [PATCH 12/14] GroupMarkerExpression, Rewriter, AbstractNode docs --- .../expression/group_marker_expression.h | 9 ++++ src/include/optimizer/abstract_node.h | 30 ++++++++++++++ src/include/optimizer/rewriter.h | 41 +++++++++++++++++++ src/optimizer/rule_rewrite.cpp | 4 +- test/optimizer/rewriter_test.cpp | 6 ++- 5 files changed, 86 insertions(+), 4 deletions(-) diff --git a/src/include/expression/group_marker_expression.h b/src/include/expression/group_marker_expression.h index 25c717f16e5..cceb46c07d2 100644 --- a/src/include/expression/group_marker_expression.h +++ b/src/include/expression/group_marker_expression.h @@ -28,6 +28,15 @@ namespace expression { // GroupMarkerExpression //===----------------------------------------------------------------------===// +/** + * When binding expressions to specific patterns, we allow for a "wildcard". + * This GroupMarkerExpression serves to encapsulate and represent an expression + * that was bound successfully to a "wildcard" pattern node. + * + * This class contains a single GroupID that can be used as a lookup into the + * Memo class for the actual expression. In short, this GroupMarkerExpression + * serves as an indirection wrapper pointing to the actual expression. 
+ */ class GroupMarkerExpression : public AbstractExpression { public: GroupMarkerExpression(optimizer::GroupID group_id) : diff --git a/src/include/optimizer/abstract_node.h b/src/include/optimizer/abstract_node.h index 3b4a1cb16b1..b9d1e43ad9d 100644 --- a/src/include/optimizer/abstract_node.h +++ b/src/include/optimizer/abstract_node.h @@ -88,29 +88,59 @@ struct AbstractNode { ~AbstractNode() {} + /** + * Accepts a visitor + * @param v Visitor + */ virtual void Accept(OperatorVisitor *v) const = 0; + /** + * @returns Name of the Node + */ virtual std::string GetName() const = 0; // TODO(ncx): dependence on OpType and ExpressionType (ideally abstracted away) + /** + * @returns OpType of the Node + */ virtual OpType GetOpType() const = 0; + /** + * @returns ExpressionType of the Node + */ virtual ExpressionType GetExpType() const = 0; + /** + * @returns whether node represents a logical operator / expression + */ virtual bool IsLogical() const = 0; + /** + * @returns whether node represents a physical operator + */ virtual bool IsPhysical() const = 0; + /** + * Hashes the AbstractNode + * @returns hash + */ virtual hash_t Hash() const { // TODO(ncx): hash should work for ExpressionType nodes OpType t = GetOpType(); return HashUtil::Hash(&t); } + /** + * Base definition of whether two AbstractNodes are equal + * Function simply checks whether OpType/ExpType match + */ virtual bool operator==(const AbstractNode &r) { return GetOpType() == r.GetOpType() && GetExpType() == r.GetExpType(); } + /** + * @returns whether the contained node is null or not + */ virtual bool IsDefined() const { return node != nullptr; } template diff --git a/src/include/optimizer/rewriter.h b/src/include/optimizer/rewriter.h index 161692a1a70..2f8668c66fd 100644 --- a/src/include/optimizer/rewriter.h +++ b/src/include/optimizer/rewriter.h @@ -25,22 +25,63 @@ namespace optimizer { class Rewriter { public: + /** + * Default constructor + */ Rewriter(); + + /** + * Resets the internal state of
the rewriter + */ void Reset(); DISALLOW_COPY_AND_MOVE(Rewriter); + /** + * Gets the OptimizerMetadata used by the rewriter + * @returns internal OptimizerMetadata + */ OptimizerMetadata &GetMetadata() { return metadata_; } + /** + * Rewrites an expression by applying applicable rules + * @param expr AbstractExpression to rewrite + * @returns rewritten AbstractExpression + */ expression::AbstractExpression* RewriteExpression(const expression::AbstractExpression *expr); private: + /** + * Creates an AbstractExpression from the Memo used internally + * @param root_group GroupID of the root group to begin building from + * @returns AbstractExpression from the stored groups + */ expression::AbstractExpression* RebuildExpression(int root_group); + + /** + * Performs a single rewrite pass on the expression + * @param root_group_id GroupID of the group to start rewriting from + */ void RewriteLoop(int root_group_id); + /** + * Converts AbstractExpression into internal rewriter representation + * @param expr expression to convert + * @returns shared pointer to rewriter internal representation + */ std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression *expr); + + /** + * Records the original groups (subtrees) of the AbstractExpression. + * From the recorded information, it is possible to rebuild the expression.
+ * @param expr expression whose groups to record + * @returns GroupExpression representing the root of the expression + */ std::shared_ptr RecordTreeGroups(const expression::AbstractExpression *expr); + /** + * OptimizerMetadata that we leverage + */ OptimizerMetadata metadata_; }; diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp index 7bcb3aaadf1..c1286392cfc 100644 --- a/src/optimizer/rule_rewrite.cpp +++ b/src/optimizer/rule_rewrite.cpp @@ -46,7 +46,7 @@ bool ComparatorElimination::Check(std::shared_ptr plan, void ComparatorElimination::Transform(std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + OptimizeContext *context) const { (void)transformed; (void)context; @@ -148,7 +148,7 @@ bool EquivalentTransform::Check(std::shared_ptr plan, void EquivalentTransform::Transform(std::shared_ptr input, std::vector> &transformed, - UNUSED_ATTRIBUTE OptimizeContext *context) const { + OptimizeContext *context) const { (void)transformed; (void)context; diff --git a/test/optimizer/rewriter_test.cpp b/test/optimizer/rewriter_test.cpp index c9faacca9c0..ac9ea9f9d7c 100644 --- a/test/optimizer/rewriter_test.cpp +++ b/test/optimizer/rewriter_test.cpp @@ -32,6 +32,7 @@ using namespace optimizer; class RewriterTests : public PelotonTest {}; TEST_F(RewriterTests, SingleCompareEqualRewritePassFalse) { + // 3 = 2 ==> FALSE type::Value leftValue = type::ValueFactory::GetIntegerValue(3); type::Value rightValue = type::ValueFactory::GetIntegerValue(2); auto left = new expression::ConstantValueExpression(leftValue); @@ -52,6 +53,7 @@ TEST_F(RewriterTests, SingleCompareEqualRewritePassFalse) { } TEST_F(RewriterTests, SingleCompareEqualRewritePassTrue) { + // 4 = 4 ==> TRUE type::Value leftValue = type::ValueFactory::GetIntegerValue(4); type::Value rightValue = type::ValueFactory::GetIntegerValue(4); auto left = new expression::ConstantValueExpression(leftValue); @@ -73,7 +75,7 @@ TEST_F(RewriterTests, 
SingleCompareEqualRewritePassTrue) { TEST_F(RewriterTests, SimpleEqualityTree) { // [=] - // [=] [=] + // [=] [=] ==> FALSE // [4] [5] [3] [3] type::Value val4 = type::ValueFactory::GetIntegerValue(4); type::Value val5 = type::ValueFactory::GetIntegerValue(5); @@ -109,7 +111,7 @@ TEST_F(RewriterTests, SimpleEqualityTree) { TEST_F(RewriterTests, ComparativeOperatorTest) { // [=] - // [<=] [>=] + // [<=] [>=] ==> TRUE // [4] [4] [5] [3] type::Value val4 = type::ValueFactory::GetIntegerValue(4); type::Value val5 = type::ValueFactory::GetIntegerValue(5); From a0315c3b29f5c6869e07369c9904399f4e963f77 Mon Sep 17 00:00:00 2001 From: Newton Xie Date: Tue, 14 May 2019 23:07:34 -0400 Subject: [PATCH 13/14] Naming conventions and style fixes. --- src/include/optimizer/absexpr_expression.h | 35 +++++------ src/include/optimizer/rewriter.h | 2 +- src/include/optimizer/rule_rewrite.h | 8 +-- src/optimizer/absexpr_expression.cpp | 2 +- src/optimizer/binding.cpp | 4 +- src/optimizer/group.cpp | 2 +- src/optimizer/memo.cpp | 2 +- src/optimizer/rewriter.cpp | 14 ++--- src/optimizer/rule_rewrite.cpp | 72 +++++++++++----------- test/optimizer/absexpr_test.cpp | 28 ++++----- 10 files changed, 83 insertions(+), 86 deletions(-) diff --git a/src/include/optimizer/absexpr_expression.h b/src/include/optimizer/absexpr_expression.h index d5c6f098be7..17dc35d0ad0 100644 --- a/src/include/optimizer/absexpr_expression.h +++ b/src/include/optimizer/absexpr_expression.h @@ -23,22 +23,22 @@ namespace peloton { namespace optimizer { -// AbsExpr_Container and AbsExpr_Expression provides and serves an analogous purpose -// to Operator and OperatorExpression. Each AbsExpr_Container wraps a single -// AbstractExpression node with the children placed inside the AbsExpr_Expression. +// AbsExprNode and AbsExprExpression provides and serves an analogous purpose +// to Operator and OperatorExpression. 
Each AbsExprNode wraps a single +// AbstractExpression with the children placed inside the AbsExprExpression. // // This is done to export the correct interface from the wrapped AbstractExpression -// to the rest of the core rule/optimizer code/logic. -class AbsExpr_Container: public AbstractNode { +// to the rest of the core optimizer logic. +class AbsExprNode: public AbstractNode { public: // Default constructors - AbsExpr_Container() = default; - AbsExpr_Container(const AbsExpr_Container &other): + AbsExprNode() = default; + AbsExprNode(const AbsExprNode &other): AbstractNode() { expr = other.expr; } - AbsExpr_Container(std::shared_ptr expr_) { + AbsExprNode(std::shared_ptr expr_) { expr = expr_; } @@ -78,8 +78,7 @@ class AbsExpr_Container: public AbstractNode { if (IsDefined()) { return expr->GetExpressionName(); } - - return "Undefined"; + throw OptimizerException("Undefined expression name."); } hash_t Hash() const { @@ -91,14 +90,14 @@ class AbsExpr_Container: public AbstractNode { bool operator==(const AbstractNode &r) { if (r.GetExpType() != ExpressionType::INVALID) { - const AbsExpr_Container &cnt = dynamic_cast(r); + const AbsExprNode &cnt = dynamic_cast(r); return (*this == cnt); } return false; } - bool operator==(const AbsExpr_Container &r) { + bool operator==(const AbsExprNode &r) { if (IsDefined() && r.IsDefined()) { //TODO(): proper equality check when migrate to terrier // Equality check relies on performing the following: @@ -108,7 +107,7 @@ class AbsExpr_Container: public AbstractNode { // are children-less, operator== provides sufficient checking. // The reason behind why the children-less guarantee is required, // is that the "real" children are actually tracked by the - // AbsExpr_Expression class. + // AbsExprExpression class. 
return false; } else if (!IsDefined() && !r.IsDefined()) { return true; @@ -130,17 +129,17 @@ class AbsExpr_Container: public AbstractNode { }; -class AbsExpr_Expression: public AbstractNodeExpression { +class AbsExprExpression: public AbstractNodeExpression { public: - AbsExpr_Expression(std::shared_ptr n) { - std::shared_ptr cnt = std::dynamic_pointer_cast(n); + AbsExprExpression(std::shared_ptr n) { + std::shared_ptr cnt = std::dynamic_pointer_cast(n); PELOTON_ASSERT(cnt != nullptr); node = n; } // Disallow copy and move constructor - DISALLOW_COPY_AND_MOVE(AbsExpr_Expression); + DISALLOW_COPY_AND_MOVE(AbsExprExpression); void PushChild(std::shared_ptr op) { children.push_back(op); @@ -156,7 +155,7 @@ class AbsExpr_Expression: public AbstractNodeExpression { const std::shared_ptr Node() const { // Integrity constraint - std::shared_ptr cnt = std::dynamic_pointer_cast(node); + std::shared_ptr cnt = std::dynamic_pointer_cast(node); PELOTON_ASSERT(cnt != nullptr); return node; diff --git a/src/include/optimizer/rewriter.h b/src/include/optimizer/rewriter.h index 2f8668c66fd..271f6611032 100644 --- a/src/include/optimizer/rewriter.h +++ b/src/include/optimizer/rewriter.h @@ -69,7 +69,7 @@ class Rewriter { * @param expr expression to convert * @returns shared pointer to rewriter internal representation */ - std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression *expr); + std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression *expr); /** * Records the original groups (subtrees) of the AbstractExpression. 
diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h index ab739aa177d..055d6d6a0b2 100644 --- a/src/include/optimizer/rule_rewrite.h +++ b/src/include/optimizer/rule_rewrite.h @@ -24,11 +24,9 @@ using GroupExprTemplate = GroupExpression; using OptimizeContext = OptimizeContext; /* Rules are applied from high to low priority */ -enum class RulePriority : int { - HIGH = 3, - MEDIUM = 2, - LOW = 1 -}; +#define HIGH_PRIORITY 3 +#define MEDIUM_PRIORITY 2 +#define LOW_PRIORITY 1 class ComparatorElimination: public Rule { public: diff --git a/src/optimizer/absexpr_expression.cpp b/src/optimizer/absexpr_expression.cpp index c0e8d5ca8da..30122603ffe 100644 --- a/src/optimizer/absexpr_expression.cpp +++ b/src/optimizer/absexpr_expression.cpp @@ -19,7 +19,7 @@ namespace peloton { namespace optimizer { -expression::AbstractExpression *AbsExpr_Container::CopyWithChildren(std::vector children) { +expression::AbstractExpression *AbsExprNode::CopyWithChildren(std::vector children) { // Pre-compute left and right expression::AbstractExpression *left = nullptr; expression::AbstractExpression *right = nullptr; diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index 986710ab0ab..40786696326 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -86,7 +86,7 @@ std::shared_ptr GroupBindingIterator::Next() { current_item_index_ = num_group_items_; auto expr = std::make_shared(group_id_); - return std::make_shared(std::make_shared(expr)); + return std::make_shared(std::make_shared(expr)); } return current_iterator_->Next(); @@ -112,7 +112,7 @@ GroupExprBindingIterator::GroupExprBindingIterator( if (gexpr->Node()->GetOpType() != OpType::Undefined) { current_binding_ = std::make_shared(gexpr->Node()); } else { - current_binding_ = std::make_shared(gexpr->Node()); + current_binding_ = std::make_shared(gexpr->Node()); } const std::vector &child_groups = gexpr->GetChildGroupIDs(); diff --git a/src/optimizer/group.cpp 
b/src/optimizer/group.cpp index b7b9c851ece..71e0a8af77a 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -32,7 +32,7 @@ void Group::AddExpression(std::shared_ptr expr, // Additional assertion checks for AddExpression() with AST rewriting // TODO(ncx): get group expression type - // if (std::is_same::value) { + // if (std::is_same::value) { // PELOTON_ASSERT(!enforced); // PELOTON_ASSERT(!expr->Op().IsPhysical()); // } diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index c61cdc81d25..c01ffc19e02 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -44,7 +44,7 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, if (gexpr->Node()->GetOpType() == OpType::Undefined && gexpr->Node()->GetExpType() == ExpressionType::GROUP_MARKER) { - auto abs_node = std::dynamic_pointer_cast(gexpr->Node()); + auto abs_node = std::dynamic_pointer_cast(gexpr->Node()); PELOTON_ASSERT(abs_node != nullptr); auto gm_expr = std::dynamic_pointer_cast(abs_node->GetExpr()); diff --git a/src/optimizer/rewriter.cpp b/src/optimizer/rewriter.cpp index 4a440af56c6..5378287dbe7 100644 --- a/src/optimizer/rewriter.cpp +++ b/src/optimizer/rewriter.cpp @@ -86,24 +86,24 @@ expression::AbstractExpression* Rewriter::RebuildExpression(int root) { child_exprs.push_back(child); } - std::shared_ptr c = std::dynamic_pointer_cast(expr->Node()); + std::shared_ptr c = std::dynamic_pointer_cast(expr->Node()); PELOTON_ASSERT(c != nullptr); return c->CopyWithChildren(child_exprs); } -std::shared_ptr Rewriter::ConvertToAbsExpr(const expression::AbstractExpression* expr) { +std::shared_ptr Rewriter::ConvertToAbsExpr(const expression::AbstractExpression* expr) { // TODO(): remove the Copy invocation when in terrier since terrier uses shared_ptr // // This Copy() is not very efficient at all. but within Peloton, this is the only way - // to present a std::shared_ptr to the AbsExpr_Container/Expression classes. 
In terrier, + // to present a std::shared_ptr to the AbsExprNode/Expression classes. In terrier, // this Copy() is *definitely* not needed because the AbstractExpression there already // utilizes std::shared_ptr properly. std::shared_ptr copy = std::shared_ptr(expr->Copy()); - // Create current AbsExpr_Expression - auto container = std::make_shared(copy); - auto expression = std::make_shared(container); + // Create current AbsExprExpression + auto container = std::make_shared(copy); + auto expression = std::make_shared(container); // Convert all the children size_t child_count = expr->GetChildrenSize(); @@ -116,7 +116,7 @@ std::shared_ptr Rewriter::ConvertToAbsExpr(const expression: } std::shared_ptr Rewriter::RecordTreeGroups(const expression::AbstractExpression *expr) { - std::shared_ptr exp = ConvertToAbsExpr(expr); + std::shared_ptr exp = ConvertToAbsExpr(expr); std::shared_ptr gexpr; metadata_.RecordTransformedExpression(exp, gexpr); return gexpr; diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp index c1286392cfc..c1439a847e3 100644 --- a/src/optimizer/rule_rewrite.cpp +++ b/src/optimizer/rule_rewrite.cpp @@ -34,7 +34,7 @@ int ComparatorElimination::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::MEDIUM); + return static_cast(MEDIUM_PRIORITY); } bool ComparatorElimination::Check(std::shared_ptr plan, @@ -54,8 +54,8 @@ void ComparatorElimination::Transform(std::shared_ptr in // Since the binding succeeded, there are guaranteed to be two children. 
PELOTON_ASSERT(input->Children().size() == 2); - auto left_abs = std::dynamic_pointer_cast(input->Children()[0]->Node()); - auto right_abs = std::dynamic_pointer_cast(input->Children()[1]->Node()); + auto left_abs = std::dynamic_pointer_cast(input->Children()[0]->Node()); + auto right_abs = std::dynamic_pointer_cast(input->Children()[1]->Node()); PELOTON_ASSERT(left_abs != nullptr && right_abs != nullptr); auto left = left_abs->GetExpr(); @@ -105,8 +105,8 @@ void ComparatorElimination::Transform(std::shared_ptr in // Create the replacement type::Value val = type::ValueFactory::GetBooleanValue(cmp); auto expr = std::make_shared(val); - auto container = std::make_shared(AbsExpr_Container(expr)); - auto shared = std::make_shared(container); + auto container = std::make_shared(AbsExprNode(expr)); + auto shared = std::make_shared(container); transformed.push_back(shared); } @@ -136,7 +136,7 @@ int EquivalentTransform::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::HIGH); + return static_cast(HIGH_PRIORITY); } bool EquivalentTransform::Check(std::shared_ptr plan, @@ -163,8 +163,8 @@ void EquivalentTransform::Transform(std::shared_ptr inpu // The children do not strictly matter anymore auto type = match_pattern->GetExpType(); auto expr = std::make_shared(type); - auto a_expr = std::make_shared(expr); - auto shared = std::make_shared(a_expr); + auto a_expr = std::make_shared(expr); + auto shared = std::make_shared(a_expr); // Create flipped ordering at logical level shared->PushChild(right); @@ -204,7 +204,7 @@ TVEqualityWithTwoCVTransform::TVEqualityWithTwoCVTransform() { int TVEqualityWithTwoCVTransform::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::LOW); + return static_cast(LOW_PRIORITY); } bool TVEqualityWithTwoCVTransform::Check(std::shared_ptr plan, OptimizeContext *context) const { 
@@ -245,10 +245,10 @@ void TVEqualityWithTwoCVTransform::Transform(std::shared_ptrNode()->GetExpType() == ExpressionType::VALUE_TUPLE); PELOTON_ASSERT(r_cv->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); - auto l_tv_c = std::dynamic_pointer_cast(l_tv->Node()); - auto r_tv_c = std::dynamic_pointer_cast(r_tv->Node()); - auto l_cv_c = std::dynamic_pointer_cast(l_cv->Node()); - auto r_cv_c = std::dynamic_pointer_cast(r_cv->Node()); + auto l_tv_c = std::dynamic_pointer_cast(l_tv->Node()); + auto r_tv_c = std::dynamic_pointer_cast(r_tv->Node()); + auto l_cv_c = std::dynamic_pointer_cast(l_cv->Node()); + auto r_cv_c = std::dynamic_pointer_cast(r_cv->Node()); PELOTON_ASSERT(l_tv_c != nullptr && r_tv_c != nullptr); PELOTON_ASSERT(l_cv_c != nullptr && r_cv_c != nullptr); @@ -278,7 +278,7 @@ void TVEqualityWithTwoCVTransform::Transform(std::shared_ptr(val); - auto abs_expr = std::make_shared(std::make_shared(AbsExpr_Container(constant))); + auto abs_expr = std::make_shared(std::make_shared(AbsExprNode(constant))); transformed.push_back(abs_expr); } @@ -314,7 +314,7 @@ TransitiveClosureConstantTransform::TransitiveClosureConstantTransform() { int TransitiveClosureConstantTransform::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::LOW); + return static_cast(LOW_PRIORITY); } bool TransitiveClosureConstantTransform::Check(std::shared_ptr plan, OptimizeContext *context) const { @@ -355,9 +355,9 @@ void TransitiveClosureConstantTransform::Transform(std::shared_ptrNode()->GetExpType() == ExpressionType::VALUE_TUPLE); PELOTON_ASSERT(r_tv_r->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); - auto l_tv_c = std::dynamic_pointer_cast(l_tv->Node()); - auto r_tv_l_c = std::dynamic_pointer_cast(r_tv_l->Node()); - auto r_tv_r_c = std::dynamic_pointer_cast(r_tv_r->Node()); + auto l_tv_c = std::dynamic_pointer_cast(l_tv->Node()); + auto r_tv_l_c = std::dynamic_pointer_cast(r_tv_l->Node()); 
+ auto r_tv_r_c = std::dynamic_pointer_cast(r_tv_r->Node()); PELOTON_ASSERT(l_tv_c != nullptr && r_tv_l_c != nullptr && r_tv_r_c != nullptr); auto l_tv_expr = l_tv_c->GetExpr(); @@ -379,7 +379,7 @@ void TransitiveClosureConstantTransform::Transform(std::shared_ptr(r_eq->Node()); + auto new_right_eq = std::make_shared(r_eq->Node()); // At this stage, we have knowledge that C.D != E.F if (l_tv_expr->ExactlyEquals(*r_tv_l_expr)) { @@ -393,7 +393,7 @@ void TransitiveClosureConstantTransform::Transform(std::shared_ptr(input->Node()); + auto abs_expr = std::make_shared(input->Node()); abs_expr->PushChild(new_left_eq); abs_expr->PushChild(new_right_eq); transformed.push_back(abs_expr); @@ -419,7 +419,7 @@ AndShortCircuit::AndShortCircuit() { int AndShortCircuit::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::HIGH); + return static_cast(HIGH_PRIORITY); } bool AndShortCircuit::Check(std::shared_ptr plan, OptimizeContext *context) const { @@ -443,7 +443,7 @@ void AndShortCircuit::Transform(std::shared_ptr input, PELOTON_ASSERT(left->Children().size() == 0); PELOTON_ASSERT(left->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); - std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); + std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); PELOTON_ASSERT(left_c != nullptr); auto left_cv_expr = std::dynamic_pointer_cast(left_c->GetExpr()); @@ -455,8 +455,8 @@ void AndShortCircuit::Transform(std::shared_ptr input, if (left_value.GetTypeId() == type::TypeId::BOOLEAN && left_value.IsFalse()) { type::Value val_false = type::ValueFactory::GetBooleanValue(false); std::shared_ptr false_expr = std::make_shared(val_false); - std::shared_ptr false_cnt = std::make_shared(false_expr); - std::shared_ptr false_container = std::make_shared(false_cnt); + std::shared_ptr false_cnt = std::make_shared(false_expr); + std::shared_ptr false_container = 
std::make_shared(false_cnt); transformed.push_back(false_container); } } @@ -477,7 +477,7 @@ OrShortCircuit::OrShortCircuit() { int OrShortCircuit::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::HIGH); + return static_cast(HIGH_PRIORITY); } bool OrShortCircuit::Check(std::shared_ptr plan, OptimizeContext *context) const { @@ -501,7 +501,7 @@ void OrShortCircuit::Transform(std::shared_ptr input, PELOTON_ASSERT(left->Children().size() == 0); PELOTON_ASSERT(left->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); - std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); + std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); PELOTON_ASSERT(left_c != nullptr); auto left_cv_expr = std::dynamic_pointer_cast(left_c->GetExpr()); @@ -511,8 +511,8 @@ void OrShortCircuit::Transform(std::shared_ptr input, if (left_value.GetTypeId() == type::TypeId::BOOLEAN && left_value.IsTrue()) { type::Value val_true = type::ValueFactory::GetBooleanValue(true); std::shared_ptr true_expr = std::make_shared(val_true); - std::shared_ptr true_cnt = std::make_shared(true_expr); - std::shared_ptr true_container = std::make_shared(true_cnt); + std::shared_ptr true_cnt = std::make_shared(true_expr); + std::shared_ptr true_container = std::make_shared(true_cnt); transformed.push_back(true_container); } } @@ -531,7 +531,7 @@ NullLookupOnNotNullColumn::NullLookupOnNotNullColumn() { int NullLookupOnNotNullColumn::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::LOW); + return static_cast(LOW_PRIORITY); } bool NullLookupOnNotNullColumn::Check(std::shared_ptr plan, OptimizeContext *context) const { @@ -555,7 +555,7 @@ void NullLookupOnNotNullColumn::Transform(std::shared_ptrChildren().size() == 0); PELOTON_ASSERT(child->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); - std::shared_ptr child_c 
= std::dynamic_pointer_cast(child->Node()); + std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); PELOTON_ASSERT(child_c != nullptr); auto tuple_expr = std::dynamic_pointer_cast(child_c->GetExpr()); @@ -565,8 +565,8 @@ void NullLookupOnNotNullColumn::Transform(std::shared_ptrGetIsNotNull()) { type::Value val_false = type::ValueFactory::GetBooleanValue(false); std::shared_ptr false_expr = std::make_shared(val_false); - std::shared_ptr false_cnt = std::make_shared(false_expr); - std::shared_ptr false_container = std::make_shared(false_cnt); + std::shared_ptr false_cnt = std::make_shared(false_expr); + std::shared_ptr false_container = std::make_shared(false_cnt); transformed.push_back(false_container); } } @@ -584,7 +584,7 @@ NotNullLookupOnNotNullColumn::NotNullLookupOnNotNullColumn() { int NotNullLookupOnNotNullColumn::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)group_expr; (void)context; - return static_cast(RulePriority::LOW); + return static_cast(LOW_PRIORITY); } bool NotNullLookupOnNotNullColumn::Check(std::shared_ptr plan, OptimizeContext *context) const { @@ -608,7 +608,7 @@ void NotNullLookupOnNotNullColumn::Transform(std::shared_ptrChildren().size() == 0); PELOTON_ASSERT(child->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); - std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); + std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); auto tuple_expr = std::dynamic_pointer_cast(child_c->GetExpr()); // Only transform into [TRUE] if the tuple value expression is specifically non-NULL, @@ -616,8 +616,8 @@ void NotNullLookupOnNotNullColumn::Transform(std::shared_ptrGetIsNotNull()) { type::Value val_true = type::ValueFactory::GetBooleanValue(true); std::shared_ptr true_expr = std::make_shared(val_true); - std::shared_ptr true_cnt = std::make_shared(true_expr); - std::shared_ptr true_container = std::make_shared(true_cnt); + std::shared_ptr true_cnt = std::make_shared(true_expr); + 
std::shared_ptr true_container = std::make_shared(true_cnt); transformed.push_back(true_container); } } diff --git a/test/optimizer/absexpr_test.cpp b/test/optimizer/absexpr_test.cpp index 1e8d233c82d..6c8ccdc917e 100644 --- a/test/optimizer/absexpr_test.cpp +++ b/test/optimizer/absexpr_test.cpp @@ -79,7 +79,7 @@ TEST_F(AbsExprTest, CompareTest) { auto right = new expression::ParameterValueExpression(1); for (auto type : compares) { auto cmp_expr = std::make_shared(type, left->Copy(), right->Copy()); - AbsExpr_Container op = AbsExpr_Container(cmp_expr); + AbsExprNode op = AbsExprNode(cmp_expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); EXPECT_TRUE(rebuilt != nullptr); @@ -114,7 +114,7 @@ TEST_F(AbsExprTest, ConjunctionTest) { auto right = new expression::ConstantValueExpression(fval); for (auto type : compares) { auto cmp_expr = std::make_shared(type, left->Copy(), right->Copy()); - AbsExpr_Container op = AbsExpr_Container(cmp_expr); + AbsExprNode op = AbsExprNode(cmp_expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); EXPECT_TRUE(rebuilt != nullptr); @@ -156,7 +156,7 @@ TEST_F(AbsExprTest, OperatorTest) { auto op_expr = std::make_shared(type, type::TypeId::INTEGER, left->Copy(), right->Copy()); op_expr->DeduceExpressionType(); - AbsExpr_Container op = AbsExpr_Container(op_expr); + AbsExprNode op = AbsExprNode(op_expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); EXPECT_TRUE(rebuilt != nullptr); rebuilt->DeduceExpressionType(); @@ -180,7 +180,7 @@ TEST_F(AbsExprTest, OperatorTest) { auto op_expr = std::make_shared(type, type::TypeId::INTEGER, left->Copy(), nullptr); op_expr->DeduceExpressionType(); - AbsExpr_Container op = AbsExpr_Container(op_expr); + AbsExprNode op = AbsExprNode(op_expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy()}); EXPECT_TRUE(rebuilt != nullptr); 
rebuilt->DeduceExpressionType(); @@ -205,7 +205,7 @@ TEST_F(AbsExprTest, OperatorUnaryMinusTest) { auto left = GetConstantExpression(25); auto unary = std::make_shared(left->Copy()); - AbsExpr_Container op = AbsExpr_Container(unary); + AbsExprNode op = AbsExprNode(unary); expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy()}); EXPECT_TRUE(rebuilt != nullptr); @@ -220,7 +220,7 @@ TEST_F(AbsExprTest, OperatorUnaryMinusTest) { TEST_F(AbsExprTest, StarTest) { auto expr = std::make_shared(); - AbsExpr_Container op = AbsExpr_Container(expr); + AbsExprNode op = AbsExprNode(expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); EXPECT_EQ(*expr, *rebuilt); @@ -230,7 +230,7 @@ TEST_F(AbsExprTest, StarTest) { TEST_F(AbsExprTest, ValueConstantTest) { auto cv_expr = dynamic_cast(GetConstantExpression(721)); auto expr = std::shared_ptr(cv_expr); - AbsExpr_Container op = AbsExpr_Container(expr); + AbsExprNode op = AbsExprNode(expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); EXPECT_EQ(*expr, *rebuilt); // this does not check value @@ -249,7 +249,7 @@ TEST_F(AbsExprTest, ValueConstantTest) { TEST_F(AbsExprTest, ValueParameterTest) { auto expr = std::make_shared(15); - AbsExpr_Container op = AbsExpr_Container(expr); + AbsExprNode op = AbsExprNode(expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); EXPECT_EQ(*expr, *rebuilt); // does not check value_idx_ @@ -264,7 +264,7 @@ TEST_F(AbsExprTest, ValueTupleTest) { expr_col->SetTupleValueExpressionParams(type::TypeId::INTEGER, 1, 1); expr_col->SetTableName("tbl"); - AbsExpr_Container op = AbsExpr_Container(expr_col); + AbsExprNode op = AbsExprNode(expr_col); expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); EXPECT_EQ(*expr_col, *rebuilt); // checks tbl_name, col_name @@ -293,7 +293,7 @@ TEST_F(AbsExprTest, AggregateNodeTest) { auto agg_expr = std::make_shared(type, true, child->Copy()); agg_expr->DeduceExpressionType(); - 
AbsExpr_Container op = AbsExpr_Container(agg_expr); + AbsExprNode op = AbsExprNode(agg_expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({child->Copy()}); EXPECT_TRUE(rebuilt != nullptr); @@ -317,7 +317,7 @@ TEST_F(AbsExprTest, AggregateNodeTest) { agg_expr->DeduceExpressionType(); EXPECT_TRUE(agg_expr->GetExpressionType() == ExpressionType::AGGREGATE_COUNT_STAR); - AbsExpr_Container op = AbsExpr_Container(agg_expr); + AbsExprNode op = AbsExprNode(agg_expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); rebuilt->DeduceExpressionType(); @@ -343,7 +343,7 @@ TEST_F(AbsExprTest, CaseExpressionTest) { clauses.push_back(expression::CaseExpression::WhenClause(std::move(where3), std::move(res3))); auto expr = std::make_shared(type::TypeId::INTEGER, clauses, std::move(def_c)); - AbsExpr_Container op = AbsExpr_Container(expr); + AbsExprNode op = AbsExprNode(expr); expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); // Checks every clause except for ConstantValue values @@ -396,7 +396,7 @@ TEST_F(AbsExprTest, SubqueryTest) { auto expr = std::make_shared(); expr->SetSubSelect(sel); - AbsExpr_Container container = AbsExpr_Container(expr); + AbsExprNode container = AbsExprNode(expr); expression::AbstractExpression *rebuild = container.CopyWithChildren({}); EXPECT_EQ(rebuild->GetExpressionType(), expr->GetExpressionType()); @@ -426,7 +426,7 @@ TEST_F(AbsExprTest, FunctionExpressionTest) { auto expr = std::make_shared("func", child1); expr->SetBuiltinFunctionExpressionParameters(func_ptr, type::TypeId::INTEGER, types); - AbsExpr_Container container = AbsExpr_Container(expr); + AbsExprNode container = AbsExprNode(expr); expression::AbstractExpression* rebuild = container.CopyWithChildren(child2); EXPECT_EQ(rebuild->GetExpressionType(), expr->GetExpressionType()); From 010407b30d22908ffa0ed2948e336b822277fc26 Mon Sep 17 00:00:00 2001 From: Newton Xie Date: Tue, 14 May 2019 23:31:53 -0400 Subject: [PATCH 14/14] Addressing TODOs 
before code merge. --- src/include/optimizer/abstract_node.h | 8 ++++---- src/include/optimizer/input_column_deriver.h | 1 - src/include/optimizer/operator_node.h | 2 +- src/include/optimizer/rule.h | 2 -- src/optimizer/group.cpp | 9 ++++----- src/optimizer/input_column_deriver.cpp | 2 -- src/optimizer/memo.cpp | 3 +-- src/optimizer/rule.cpp | 1 - 8 files changed, 10 insertions(+), 18 deletions(-) diff --git a/src/include/optimizer/abstract_node.h b/src/include/optimizer/abstract_node.h index b9d1e43ad9d..902ae575967 100644 --- a/src/include/optimizer/abstract_node.h +++ b/src/include/optimizer/abstract_node.h @@ -79,7 +79,6 @@ enum class OpType { //===--------------------------------------------------------------------===// // Abstract Node //===--------------------------------------------------------------------===// -//TODO(ncx): dependence on OperatorVisitor class OperatorVisitor; struct AbstractNode { @@ -88,6 +87,7 @@ struct AbstractNode { ~AbstractNode() {} + //TODO: dependence on OperatorVisitor should ideally be abstracted away /** * Accepts a visitor * @param v Visitor @@ -99,7 +99,7 @@ struct AbstractNode { */ virtual std::string GetName() const = 0; - // TODO(ncx): dependence on OpType and ExpressionType (ideally abstracted away) + // TODO: dependencies on OpType and ExpressionType also not ideal /** * @returns OpType of the Node */ @@ -125,9 +125,9 @@ struct AbstractNode { * @returns hash */ virtual hash_t Hash() const { - // TODO(ncx): hash should work for ExpressionType nodes OpType t = GetOpType(); - return HashUtil::Hash(&t); + ExpressionType u = GetExpType(); + return t != OpType::Undefined ? 
HashUtil::Hash(&t) : HashUtil::Hash(&u); } /** diff --git a/src/include/optimizer/input_column_deriver.h b/src/include/optimizer/input_column_deriver.h index 86ce053ba27..4b218a9b99b 100644 --- a/src/include/optimizer/input_column_deriver.h +++ b/src/include/optimizer/input_column_deriver.h @@ -100,7 +100,6 @@ class InputColumnDeriver : public OperatorVisitor { * @brief Provide all tuple value expressions needed in the expression */ void ScanHelper(); - // TODO(ncx): BaseOperatorNode void AggregateHelper(const AbstractNode *); void JoinHelper(const AbstractNode *op); diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index ec896978a75..f450a2cea74 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -27,7 +27,7 @@ namespace optimizer { //===--------------------------------------------------------------------===// class OperatorVisitor; -// Curiously recurring template pattern +// TODO: should probably use a new AbstractBaseNode interface, not AbstractNode template struct OperatorNode : public AbstractNode { OperatorNode() : AbstractNode(nullptr) {} diff --git a/src/include/optimizer/rule.h b/src/include/optimizer/rule.h index cbd21a79738..eb1711342ed 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -30,7 +30,6 @@ class Rule { public: virtual ~Rule(){}; - // TODO(ncx): pattern std::shared_ptr GetMatchPattern() const { return match_pattern; } bool IsPhysical() const { @@ -92,7 +91,6 @@ class Rule { inline uint32_t GetRuleIdx() { return static_cast(type_); } protected: - // TODO(ncx): pattern std::shared_ptr match_pattern; RuleType type_; }; diff --git a/src/optimizer/group.cpp b/src/optimizer/group.cpp index 71e0a8af77a..55f6093be75 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -31,11 +31,10 @@ void Group::AddExpression(std::shared_ptr expr, bool enforced) { // Additional assertion checks for AddExpression() with AST rewriting - 
// TODO(ncx): get group expression type - // if (std::is_same::value) { - // PELOTON_ASSERT(!enforced); - // PELOTON_ASSERT(!expr->Op().IsPhysical()); - // } + if (expr->Node()->GetExpType() != ExpressionType::INVALID) { + PELOTON_ASSERT(!enforced); + PELOTON_ASSERT(!expr->Node()->IsPhysical()); + } // Do duplicate detection expr->SetGroupID(id_); diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index ddc14152cf8..c2f94cc2d3d 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -200,7 +200,6 @@ void InputColumnDeriver::ScanHelper() { output_cols, {}}; } -// TODO(ncx): BaseOperatorNode void InputColumnDeriver::AggregateHelper(const AbstractNode *op) { ExprSet input_cols_set; ExprMap output_cols_map; @@ -270,7 +269,6 @@ void InputColumnDeriver::AggregateHelper(const AbstractNode *op) { output_cols, {input_cols}}; } -// TODO(ncx): BaseOperatorNode void InputColumnDeriver::JoinHelper(const AbstractNode *op) { const vector *join_conds = nullptr; const vector> *left_keys = nullptr; diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index c01ffc19e02..9f649cc00ae 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -115,7 +115,6 @@ GroupID Memo::AddNewGroup(std::shared_ptr gexpr) { std::unordered_set table_aliases; auto op_type = gexpr->Node()->GetOpType(); - // TODO(ncx): specialize (if not OpType, then just add new group) if (op_type == OpType::Get) { // For base group, the table alias can get directly from logical get const LogicalGet *logical_get = gexpr->Node()->As(); @@ -124,7 +123,7 @@ GroupID Memo::AddNewGroup(std::shared_ptr gexpr) { const LogicalQueryDerivedGet *query_get = gexpr->Node()->As(); table_aliases.insert(query_get->table_alias); - } else { + } else if (op_type != OpType::Undefined) { // For other groups, need to aggregate the table alias from children for (auto child_group_id : gexpr->GetChildGroupIDs()) { Group *child_group = 
GetGroupByID(child_group_id); diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index fc6e814e58c..bb2518d8732 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -40,7 +40,6 @@ int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { return LOG_PROMISE; } -// TODO(ncx): best way to specialize for constructors? RuleSet::RuleSet() { // Comparator Elimination related rules std::vector> comp_elim_pairs = {