diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 2ccd85dcdb2..183c73e817d 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -12,6 +12,8 @@ #include "binder/bind_node_visitor.h" #include "catalog/catalog.h" +#include "catalog/table_catalog.h" +#include "catalog/column_catalog.h" #include "expression/expression_util.h" #include "expression/star_expression.h" #include "type/type_id.h" @@ -250,6 +252,21 @@ void BindNodeVisitor::Visit(expression::TupleValueExpression *expr) { expr->SetColName(col_name); expr->SetValueType(value_type); expr->SetBoundOid(col_pos_tuple); + + // TODO(esargent): Uncommenting the following code makes AddressSanitizer get mad at me with a + // heap buffer overflow whenever I try a query that references the same non-null attribute multiple + // times (e.g. 'SELECT id FROM t WHERE id < 3 AND id > 1'). Leaving it commented out prevents the + // memory error, but then this prevents the is_not_null flag of a tuple expression from being + // populated in some cases (specifically, when the expression's table name is initially empty). + + //if (table_obj == nullptr) { + // LOG_DEBUG("Extracting regular table object"); + // BinderContext::GetRegularTableObj(context_, table_name, table_obj, depth); + //} + + if (table_obj != nullptr) { + expr->SetIsNotNull(table_obj->GetColumnCatalogEntry(std::get<2>(col_pos_tuple), false)->IsNotNull()); + } } } diff --git a/src/include/common/internal_types.h b/src/include/common/internal_types.h index 96b45f9e42b..e81ec101b02 100644 --- a/src/include/common/internal_types.h +++ b/src/include/common/internal_types.h @@ -1,15 +1,3 @@ -//===----------------------------------------------------------------------===// -// -// Peloton -// -// internal_types.h -// -// Identification: src/include/common/internal_types.h -// -// Copyright (c) 2015-2018, Carnegie Mellon University Database Group -// -//===----------------------------------------------------------------------===// - //===----------------------------------------------------------------------===// // // Peloton @@ -258,7 +246,12 @@ enum class ExpressionType { // ----------------------------- // Miscellaneous // ----------------------------- - CAST = 600 + CAST = 600, + + // ----------------------------- + // Rewriter-specific identifier + // ----------------------------- + GROUP_MARKER = 721 }; // When short_str is true, return a short version of ExpressionType string @@ -1383,6 +1376,31 @@ enum class RuleType : uint32_t { PULL_FILTER_THROUGH_MARK_JOIN, PULL_FILTER_THROUGH_AGGREGATION, + // AST rewrite rules (logical -> logical) + // Removes ConstantValue =/!=//<=/>= ConstantValue + CONSTANT_COMPARE_EQUAL, + CONSTANT_COMPARE_NOTEQUAL, + CONSTANT_COMPARE_LESSTHAN, + CONSTANT_COMPARE_GREATERTHAN, + CONSTANT_COMPARE_LESSTHANOREQUALTO, + CONSTANT_COMPARE_GREATERTHANOREQUALTO, + + // Logical equivalent + EQUIV_AND, + EQUIV_OR, + EQUIV_COMPARE_EQUAL, + + TV_EQUALITY_WITH_TWO_CV, // (A.B = x) AND (A.B = y) where x/y are constant + TRANSITIVE_CLOSURE_CONSTANT, // (A.B = x) AND (A.B = C.D) + + // Boolean short-circuit rules + AND_SHORT_CIRCUIT, // (FALSE AND B) + OR_SHORT_CIRCUIT, // (TRUE OR B) + + // Catalog-based NULL/NON-NULL rules + NULL_LOOKUP_ON_NOT_NULL_COLUMN, + NOT_NULL_LOOKUP_ON_NOT_NULL_COLUMN, + // Place holder to generate number of rules compile time NUM_RULES diff --git a/src/include/expression/abstract_expression.h b/src/include/expression/abstract_expression.h index 6acdf7b2751..5cfd3614cb8 100644 --- a/src/include/expression/abstract_expression.h +++ b/src/include/expression/abstract_expression.h @@ -111,6 +111,24 @@ class AbstractExpression : public Printable { children_[index].reset(expr); } + void ClearChildren() { + // The rewriter copies the AbstractExpression presented to the rewriter. + // This function is used by the rewriter to properly wipe all the children + // of AbstractExpression once the AbstractExpression tree has been converted + // to the intermediary AbsExpr_Container/Expression tree. + // + // This function should only be invoked on copied AbstractExpressions and + // never on original ones passed into the [rewriter] otherwise we would + // violate immutable properties. We do not believe this function is strictly + // necessary in terrier, however this function does serve some usefulness. + // + // This allows us to reduce our memory footprint while also providing better + // implementation constraints within the rewriter (i.e: the rewriter should + // not be trying to operate directly on AbstractExpression but rather on the + // intermediary representation). + children_.clear(); + } + void SetExpressionType(ExpressionType type) { exp_type_ = type; } ////////////////////////////////////////////////////////////////////////////// diff --git a/src/include/expression/group_marker_expression.h b/src/include/expression/group_marker_expression.h new file mode 100644 index 00000000000..cceb46c07d2 --- /dev/null +++ b/src/include/expression/group_marker_expression.h @@ -0,0 +1,73 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// group_marker_expression.h +// +// Identification: src/include/expression/group_marker_expression.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "expression/abstract_expression.h" +#include "optimizer/group_expression.h" +#include "util/hash_util.h" + +namespace peloton { + +namespace executor { +class ExecutorContext; +} // namespace executor + +namespace expression { + +//===----------------------------------------------------------------------===// +// GroupMarkerExpression +//===----------------------------------------------------------------------===// + +/** + * When binding expressions to specific patterns, we allow for a "wildcard". + * This GroupMarkerExpression serves to encapsulate and represent an expression + * that was bound successfully to a "wildcard" pattern node. + * + * This class contains a single GroupID that can be used as a lookup into the + * Memo class for the actual expression. In short, this GroupMarkerExpression + * serves as an indirection wrapper pointing to the actual expression. + */ +class GroupMarkerExpression : public AbstractExpression { + public: + GroupMarkerExpression(optimizer::GroupID group_id) : + AbstractExpression(ExpressionType::GROUP_MARKER), + group_id_(group_id) {}; + + optimizer::GroupID GetGroupID() { return group_id_; } + + AbstractExpression *Copy() const override { + return new GroupMarkerExpression(group_id_); + } + + type::Value Evaluate(const AbstractTuple *tuple1, + const AbstractTuple *tuple2, + executor::ExecutorContext *context) const { + (void)tuple1; + (void)tuple2; + (void)context; + PELOTON_ASSERT(0); + } + + void Accept(SqlNodeVisitor *) { + PELOTON_ASSERT(0); + } + + protected: + optimizer::GroupID group_id_; + + GroupMarkerExpression(const GroupMarkerExpression &other) + : AbstractExpression(other), group_id_(other.group_id_) {} +}; + +} // namespace expression +} // namespace peloton diff --git a/src/include/expression/tuple_value_expression.h b/src/include/expression/tuple_value_expression.h index dab5d1e4ddd..16f37ae8645 100644 --- a/src/include/expression/tuple_value_expression.h +++ b/src/include/expression/tuple_value_expression.h @@ -79,6 +79,10 @@ class TupleValueExpression : public AbstractExpression { tuple_idx_ = tuple_idx; } + inline void SetIsNotNull(bool is_not_null) { + is_not_null_ = is_not_null; + } + /** * @brief Attribute binding * @param binding_contexts @@ -116,6 +120,8 @@ class TupleValueExpression : public AbstractExpression { if ((table_name_.empty() xor other.table_name_.empty()) || col_name_.empty() xor other.col_name_.empty()) return false; + if (GetIsNotNull() != other.GetIsNotNull()) + return false; bool res = bound_obj_id_ == other.bound_obj_id_; if (!table_name_.empty() && !other.table_name_.empty()) res = table_name_ == other.table_name_ && res; @@ -151,6 +157,8 @@ class TupleValueExpression : public AbstractExpression { bool GetIsBound() const { return is_bound_; } + bool GetIsNotNull() const { return is_not_null_; } + const std::tuple &GetBoundOid() const { return bound_obj_id_; } @@ -185,7 +193,8 @@ class TupleValueExpression : public AbstractExpression { value_idx_(other.value_idx_), tuple_idx_(other.tuple_idx_), table_name_(other.table_name_), - col_name_(other.col_name_) {} + col_name_(other.col_name_), + is_not_null_(other.is_not_null_) {} // Bound flag bool is_bound_ = false; @@ -196,6 +205,7 @@ class TupleValueExpression : public AbstractExpression { int tuple_idx_; std::string table_name_; std::string col_name_; + bool is_not_null_ = false; const planner::AttributeInfo *ai_; }; diff --git a/src/include/optimizer/absexpr_expression.h b/src/include/optimizer/absexpr_expression.h new file mode 100644 index 00000000000..17dc35d0ad0 --- /dev/null +++ b/src/include/optimizer/absexpr_expression.h @@ -0,0 +1,175 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// absexpr_expression.h +// +// Identification: src/include/optimizer/absexpr_expression.h +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "optimizer/abstract_node_expression.h" +#include "optimizer/abstract_node.h" +#include "expression/abstract_expression.h" +#include "expression/conjunction_expression.h" +#include "expression/comparison_expression.h" +#include "expression/constant_value_expression.h" + +#include +#include + +namespace peloton { +namespace optimizer { + +// AbsExprNode and AbsExprExpression provides and serves an analogous purpose +// to Operator and OperatorExpression. Each AbsExprNode wraps a single +// AbstractExpression with the children placed inside the AbsExprExpression. +// +// This is done to export the correct interface from the wrapped AbstractExpression +// to the rest of the core optimizer logic. +class AbsExprNode: public AbstractNode { + public: + // Default constructors + AbsExprNode() = default; + AbsExprNode(const AbsExprNode &other): + AbstractNode() { + expr = other.expr; + } + + AbsExprNode(std::shared_ptr expr_) { + expr = expr_; + } + + OpType GetOpType() const { + return OpType::Undefined; + } + + // Return operator type + ExpressionType GetExpType() const { + if (IsDefined()) { + return expr->GetExpressionType(); + } + return ExpressionType::INVALID; + } + + std::shared_ptr GetExpr() const { + return expr; + } + + // Dummy Accept + void Accept(OperatorVisitor *v) const { + (void)v; + PELOTON_ASSERT(0); + } + + // Operator contains Logical node + bool IsLogical() const { + return true; + } + + // Operator contains Physical node + bool IsPhysical() const { + return false; + } + + std::string GetName() const { + if (IsDefined()) { + return expr->GetExpressionName(); + } + throw OptimizerException("Undefined expression name."); + } + + hash_t Hash() const { + if (IsDefined()) { + return expr->Hash(); + } + return 0; + } + + bool operator==(const AbstractNode &r) { + if (r.GetExpType() != ExpressionType::INVALID) { + const AbsExprNode &cnt = dynamic_cast(r); + return (*this == cnt); + } + + return false; + } + + bool operator==(const AbsExprNode &r) { + if (IsDefined() && r.IsDefined()) { + //TODO(): proper equality check when migrate to terrier + // Equality check relies on performing the following: + // - Check each node's ExpressionType + // - Check other parameters for a given node + // We believe that in terrier so long as the AbstractExpression + // are children-less, operator== provides sufficient checking. + // The reason behind why the children-less guarantee is required, + // is that the "real" children are actually tracked by the + // AbsExprExpression class. + return false; + } else if (!IsDefined() && !r.IsDefined()) { + return true; + } + return false; + } + + // Operator contains physical or logical operator node + bool IsDefined() const { + return expr != nullptr; + } + + //TODO(): Function should use std::shared_ptr when migrate to terrier + expression::AbstractExpression *CopyWithChildren(std::vector children); + + private: + // Internal wrapped AbstractExpression + std::shared_ptr expr; +}; + + +class AbsExprExpression: public AbstractNodeExpression { + public: + AbsExprExpression(std::shared_ptr n) { + std::shared_ptr cnt = std::dynamic_pointer_cast(n); + PELOTON_ASSERT(cnt != nullptr); + + node = n; + } + + // Disallow copy and move constructor + DISALLOW_COPY_AND_MOVE(AbsExprExpression); + + void PushChild(std::shared_ptr op) { + children.push_back(op); + } + + void PopChild() { + children.pop_back(); + } + + const std::vector> &Children() const { + return children; + } + + const std::shared_ptr Node() const { + // Integrity constraint + std::shared_ptr cnt = std::dynamic_pointer_cast(node); + PELOTON_ASSERT(cnt != nullptr); + + return node; + } + + const std::string GetInfo() const { + //TODO(): create proper info statement? + return ""; + } + + private: + std::shared_ptr node; + std::vector> children; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/abstract_node.h b/src/include/optimizer/abstract_node.h new file mode 100644 index 00000000000..902ae575967 --- /dev/null +++ b/src/include/optimizer/abstract_node.h @@ -0,0 +1,159 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// abstract_node.h +// +// Identification: src/include/optimizer/abstract_node.h +// +// Copyright (c) 2015-16, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "optimizer/property_set.h" +#include "util/hash_util.h" + +#include +#include + +namespace peloton { +namespace optimizer { + +enum class OpType { + Undefined = 0, + // Special match operators + Leaf, + // Logical ops + Get, + LogicalExternalFileGet, + LogicalQueryDerivedGet, + LogicalProjection, + LogicalFilter, + LogicalMarkJoin, + LogicalDependentJoin, + LogicalSingleJoin, + InnerJoin, + LeftJoin, + RightJoin, + OuterJoin, + SemiJoin, + LogicalAggregateAndGroupBy, + LogicalInsert, + LogicalInsertSelect, + LogicalDelete, + LogicalUpdate, + LogicalLimit, + LogicalDistinct, + LogicalExportExternalFile, + // Separate between logical and physical ops + LogicalPhysicalDelimiter, + // Physical ops + DummyScan, /* Dummy Physical Op for SELECT without FROM*/ + SeqScan, + IndexScan, + ExternalFileScan, + QueryDerivedScan, + OrderBy, + PhysicalLimit, + Distinct, + InnerNLJoin, + LeftNLJoin, + RightNLJoin, + OuterNLJoin, + InnerHashJoin, + LeftHashJoin, + RightHashJoin, + OuterHashJoin, + Insert, + InsertSelect, + Delete, + Update, + Aggregate, + HashGroupBy, + SortGroupBy, + ExportExternalFile, +}; + +//===--------------------------------------------------------------------===// +// Abstract Node +//===--------------------------------------------------------------------===// +class OperatorVisitor; + +struct AbstractNode { + AbstractNode() {} + AbstractNode(std::shared_ptr node) : node(node) {} + + ~AbstractNode() {} + + //TODO: dependence on OperatorVisitor should ideally be abstracted away + /** + * Accepts a visitor + * @param v Visitor + */ + virtual void Accept(OperatorVisitor *v) const = 0; + + /** + * @returns Name fo the Node + */ + virtual std::string GetName() const = 0; + + // TODO: dependencies on OpType and ExpressionType also not ideal + /** + * @returns OpType of the Node + */ + virtual OpType GetOpType() const = 0; + + /** + * @returns ExpressionType of the Node + */ + virtual ExpressionType GetExpType() const = 0; + + /** + * @returns whether node represents a logical operator / expression + */ + virtual bool IsLogical() const = 0; + + /** + * @returns whether node represents a physical operator + */ + virtual bool IsPhysical() const = 0; + + /** + * Hashes the AbstractNode + * @returns hash + */ + virtual hash_t Hash() const { + OpType t = GetOpType(); + ExpressionType u = GetExpType(); + return t != OpType::Undefined ? HashUtil::Hash(&t) : HashUtil::Hash(&u); + } + + /** + * Base definition of whether two AbstractNodes are equal + * Function simply checks whether OpType/ExpType match + */ + virtual bool operator==(const AbstractNode &r) { + return GetOpType() == r.GetOpType() && GetExpType() == r.GetExpType(); + } + + /** + * @returns whether the contained node is null or not + */ + virtual bool IsDefined() const { return node != nullptr; } + + template + const T *As() const { + if (node && typeid(*node) == typeid(T)) { + return (const T *)node.get(); + } + return nullptr; + } + + protected: + std::shared_ptr node; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/abstract_node_expression.h b/src/include/optimizer/abstract_node_expression.h new file mode 100644 index 00000000000..85736616032 --- /dev/null +++ b/src/include/optimizer/abstract_node_expression.h @@ -0,0 +1,44 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// abstract_node_expression.h +// +// Identification: src/include/optimizer/abstract_node_expression.h +// +// Copyright (c) 2015-19, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include "optimizer/abstract_node.h" + +namespace peloton { +namespace optimizer { + +//===--------------------------------------------------------------------===// +// Abstract Node Expression +//===--------------------------------------------------------------------===// +class AbstractNodeExpression { + public: + AbstractNodeExpression() {} + + ~AbstractNodeExpression() {} + + virtual void PushChild(std::shared_ptr child) = 0; + + virtual void PopChild() = 0; + + virtual const std::vector> &Children() const = 0; + + virtual const std::shared_ptr Node() const = 0; + + virtual const std::string GetInfo() const = 0; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/binding.h b/src/include/optimizer/binding.h index 7a6d772813d..233e27f3aea 100644 --- a/src/include/optimizer/binding.h +++ b/src/include/optimizer/binding.h @@ -12,13 +12,14 @@ #pragma once +#include "operator_expression.h" #include "optimizer/operator_node.h" #include "optimizer/group.h" #include "optimizer/pattern.h" + #include #include #include -#include "operator_expression.h" namespace peloton { namespace optimizer { @@ -37,7 +38,7 @@ class BindingIterator { virtual bool HasNext() = 0; - virtual std::shared_ptr Next() = 0; + virtual std::shared_ptr Next() = 0; protected: Memo &memo_; @@ -45,12 +46,11 @@ class BindingIterator { class GroupBindingIterator : public BindingIterator { public: - GroupBindingIterator(Memo& memo, GroupID id, - std::shared_ptr pattern); + GroupBindingIterator(Memo& memo, GroupID id, std::shared_ptr pattern); bool HasNext() override; - std::shared_ptr Next() override; + std::shared_ptr Next() override; private: GroupID group_id_; @@ -58,19 +58,20 @@ class GroupBindingIterator : public BindingIterator { Group *target_group_; size_t num_group_items_; + // Internal function for HasNext() + bool HasNextBinding(); + size_t current_item_index_; std::unique_ptr current_iterator_; }; class GroupExprBindingIterator : public BindingIterator { public: - GroupExprBindingIterator(Memo& memo, - GroupExpression *gexpr, - std::shared_ptr pattern); + GroupExprBindingIterator(Memo& memo, GroupExpression *gexpr, std::shared_ptr pattern); bool HasNext() override; - std::shared_ptr Next() override; + std::shared_ptr Next() override; private: GroupExpression* gexpr_; @@ -78,8 +79,8 @@ class GroupExprBindingIterator : public BindingIterator { bool first_; bool has_next_; - std::shared_ptr current_binding_; - std::vector>> + std::shared_ptr current_binding_; + std::vector>> children_bindings_; std::vector children_bindings_pos_; }; diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index 914cc77ab27..9a64c6af4e7 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -13,6 +13,7 @@ #pragma once #include #include "optimizer/operator_visitor.h" +#include "optimizer/operator_expression.h" namespace peloton { @@ -33,8 +34,10 @@ class ChildPropertyDeriver : public OperatorVisitor { public: std::vector, std::vector>>> + GetProperties(GroupExpression *gexpr, - std::shared_ptr requirements, Memo *memo); + std::shared_ptr requirements, + Memo *memo); void Visit(const DummyScan *) override; void Visit(const PhysicalSeqScan *) override; @@ -74,8 +77,8 @@ class ChildPropertyDeriver : public OperatorVisitor { * @brief We need the memo and gexpr because some property may depend on * child's schema */ - Memo *memo_; GroupExpression *gexpr_; + Memo *memo_; }; } // namespace optimizer diff --git a/src/include/optimizer/cost_model/abstract_cost_model.h b/src/include/optimizer/cost_model/abstract_cost_model.h index 95a593f04d9..0a57be183d7 100644 --- a/src/include/optimizer/cost_model/abstract_cost_model.h +++ b/src/include/optimizer/cost_model/abstract_cost_model.h @@ -13,6 +13,7 @@ #pragma once #include "optimizer/operator_visitor.h" +#include "optimizer/operator_expression.h" namespace peloton { namespace optimizer { @@ -34,7 +35,8 @@ static constexpr double DEFAULT_OPERATOR_COST = 0.0025; class AbstractCostModel : public OperatorVisitor { public: - virtual double CalculateCost(GroupExpression *gexpr, Memo *memo, + virtual double CalculateCost(GroupExpression *gexpr, + Memo *memo, concurrency::TransactionContext *txn) = 0; }; diff --git a/src/include/optimizer/cost_model/default_cost_model.h b/src/include/optimizer/cost_model/default_cost_model.h index a92cb091db7..bbabd3d5a19 100644 --- a/src/include/optimizer/cost_model/default_cost_model.h +++ b/src/include/optimizer/cost_model/default_cost_model.h @@ -24,17 +24,18 @@ namespace peloton { namespace optimizer { class Memo; + // Derive cost for a physical group expression class DefaultCostModel : public AbstractCostModel { public: DefaultCostModel(){}; double CalculateCost(GroupExpression *gexpr, Memo *memo, - concurrency::TransactionContext *txn) { + concurrency::TransactionContext *txn) { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - gexpr_->Op().Accept(this); + gexpr_->Node()->Accept(this); return output_cost_; } diff --git a/src/include/optimizer/cost_model/postgres_cost_model.h b/src/include/optimizer/cost_model/postgres_cost_model.h index 2632a247a39..d83e440ad2e 100644 --- a/src/include/optimizer/cost_model/postgres_cost_model.h +++ b/src/include/optimizer/cost_model/postgres_cost_model.h @@ -29,6 +29,7 @@ namespace peloton { namespace optimizer { class Memo; + // Derive cost for a physical group expression class PostgresCostModel : public AbstractCostModel { public: @@ -39,7 +40,7 @@ class PostgresCostModel : public AbstractCostModel { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - gexpr_->Op().Accept(this); + gexpr_->Node()->Accept(this); return output_cost_; }; @@ -279,4 +280,4 @@ class PostgresCostModel : public AbstractCostModel { }; } // namespace optimizer -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/optimizer/cost_model/trivial_cost_model.h b/src/include/optimizer/cost_model/trivial_cost_model.h index 2c5994ee728..1fc537b059f 100644 --- a/src/include/optimizer/cost_model/trivial_cost_model.h +++ b/src/include/optimizer/cost_model/trivial_cost_model.h @@ -32,6 +32,7 @@ namespace peloton { namespace optimizer { class Memo; + class TrivialCostModel : public AbstractCostModel { public: TrivialCostModel(){}; @@ -41,7 +42,7 @@ class TrivialCostModel : public AbstractCostModel { gexpr_ = gexpr; memo_ = memo; txn_ = txn; - gexpr_->Op().Accept(this); + gexpr_->Node()->Accept(this); return output_cost_; }; @@ -116,4 +117,4 @@ class TrivialCostModel : public AbstractCostModel { }; } // namespace optimizer -} // namespace peloton \ No newline at end of file +} // namespace peloton diff --git a/src/include/optimizer/group.h b/src/include/optimizer/group.h index a0606d1597c..e2f24ca953a 100644 --- a/src/include/optimizer/group.h +++ b/src/include/optimizer/group.h @@ -39,7 +39,8 @@ class Group : public Printable { // If the GroupExpression is generated by applying a // property enforcer, we add them to enforced_exprs_ // which will not be enumerated during OptimizeExpression - void AddExpression(std::shared_ptr expr, bool enforced); + void AddExpression(std::shared_ptr expr, + bool enforced); void RemoveLogicalExpression(size_t idx) { logical_expressions_.erase(logical_expressions_.begin() + idx); @@ -98,7 +99,10 @@ class Group : public Printable { // This is called in rewrite phase to erase the only logical expression in the // group inline void EraseLogicalExpression() { - PELOTON_ASSERT(logical_expressions_.size() == 1); + // During query rewriting (pre-optimizer), the rewriter can execute in a scenario + // where a group can have multiple logical expressions (due to AND/OR/= equivalence). + // TODO(): refine these assertions to distinguish between optimizer/rewrite stages + PELOTON_ASSERT(logical_expressions_.size() >= 1); PELOTON_ASSERT(physical_expressions_.size() == 0); logical_expressions_.clear(); } diff --git a/src/include/optimizer/group_expression.h b/src/include/optimizer/group_expression.h index 303ebaf036e..a39c9471e59 100644 --- a/src/include/optimizer/group_expression.h +++ b/src/include/optimizer/group_expression.h @@ -12,7 +12,7 @@ #pragma once -#include "optimizer/operator_node.h" +#include "optimizer/abstract_node.h" #include "optimizer/stats/stats.h" #include "optimizer/util.h" #include "optimizer/property_set.h" @@ -34,7 +34,7 @@ using GroupID = int32_t; //===--------------------------------------------------------------------===// class GroupExpression { public: - GroupExpression(Operator op, std::vector child_groups); + GroupExpression(std::shared_ptr node, std::vector child_groups); GroupID GetGroupID() const; @@ -46,7 +46,7 @@ class GroupExpression { GroupID GetChildGroupId(int child_idx) const; - Operator Op() const; + std::shared_ptr Node() const; double GetCost(std::shared_ptr& requirements) const; @@ -75,7 +75,7 @@ class GroupExpression { private: GroupID group_id; - Operator op; + std::shared_ptr node; std::vector child_groups; std::bitset(RuleType::NUM_RULES)> rule_mask_; bool stats_derived_; @@ -89,14 +89,3 @@ class GroupExpression { } // namespace optimizer } // namespace peloton - -namespace std { - -template <> -struct hash { - typedef peloton::optimizer::GroupExpression argument_type; - typedef std::size_t result_type; - result_type operator()(argument_type const &s) const { return s.Hash(); } -}; - -} // namespace std diff --git a/src/include/optimizer/input_column_deriver.h b/src/include/optimizer/input_column_deriver.h index ef66823bba0..4b218a9b99b 100644 --- a/src/include/optimizer/input_column_deriver.h +++ b/src/include/optimizer/input_column_deriver.h @@ -100,8 +100,8 @@ class InputColumnDeriver : public OperatorVisitor { * @brief Provide all tuple value expressions needed in the expression */ void ScanHelper(); - void AggregateHelper(const BaseOperatorNode *); - void JoinHelper(const BaseOperatorNode *op); + void AggregateHelper(const AbstractNode *); + void JoinHelper(const AbstractNode *op); /** * @brief Some operators, for example limit, directly pass down column diff --git a/src/include/optimizer/memo.h b/src/include/optimizer/memo.h index 951caa4c94d..4ad1633ae4d 100644 --- a/src/include/optimizer/memo.h +++ b/src/include/optimizer/memo.h @@ -27,8 +27,7 @@ struct GExprPtrHash { }; struct GExprPtrEq { - bool operator()(GroupExpression* const& t1, - GroupExpression* const& t2) const { + bool operator()(GroupExpression* const& t1, GroupExpression* const& t2) const { return *t1 == *t2; } }; @@ -48,11 +47,9 @@ class Memo { * target_group: an optional target group to insert expression into * return: existing expression if found. Otherwise, return the new expr */ - GroupExpression* InsertExpression(std::shared_ptr gexpr, - bool enforced); + GroupExpression* InsertExpression(std::shared_ptr gexpr, bool enforced); - GroupExpression* InsertExpression(std::shared_ptr gexpr, - GroupID target_group, bool enforced); + GroupExpression* InsertExpression(std::shared_ptr gexpr, GroupID target_group, bool enforced); std::vector>& Groups(); @@ -78,8 +75,11 @@ class Memo { // a new one, which reqires us to first remove the original gexpr from the // memo void EraseExpression(GroupID group_id) { - auto gexpr = groups_[group_id]->GetLogicalExpression(); - group_expressions_.erase(gexpr); + std::vector> gexprs = groups_[group_id]->GetLogicalExpressions(); + for (auto gexpr : gexprs) { + group_expressions_.erase(gexpr.get()); + } + groups_[group_id]->EraseLogicalExpression(); } @@ -87,8 +87,7 @@ class Memo { GroupID AddNewGroup(std::shared_ptr gexpr); // The group owns the group expressions, not the memo - std::unordered_set - group_expressions_; + std::unordered_set group_expressions_; std::vector> groups_; size_t rule_set_size_; }; diff --git a/src/include/optimizer/operator_expression.h b/src/include/optimizer/operator_expression.h index a37020915f4..2bdd4018b67 100644 --- a/src/include/optimizer/operator_expression.h +++ b/src/include/optimizer/operator_expression.h @@ -2,16 +2,17 @@ // // Peloton // -// op_expression.h +// operator_expression.h // -// Identification: src/include/optimizer/op_expression.h +// Identification: src/include/optimizer/operator_expression.h // -// Copyright (c) 2015-16, Carnegie Mellon University Database Group +// Copyright (c) 2015-19, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// #pragma once +#include "optimizer/abstract_node_expression.h" #include "optimizer/operator_node.h" #include @@ -21,25 +22,25 @@ namespace peloton { namespace optimizer { //===--------------------------------------------------------------------===// -// Operator Expr +// Operator Expression //===--------------------------------------------------------------------===// -class OperatorExpression { +class OperatorExpression : public AbstractNodeExpression { public: - OperatorExpression(Operator op); + OperatorExpression(std::shared_ptr node); - void PushChild(std::shared_ptr op); + void PushChild(std::shared_ptr child); void PopChild(); - const std::vector> &Children() const; + const std::vector> &Children() const; - const Operator &Op() const; + const std::shared_ptr Node() const; const std::string GetInfo() const; private: - Operator op; - std::vector> children; + std::shared_ptr node; + std::vector> children; }; } // namespace optimizer diff --git a/src/include/optimizer/operator_node.h b/src/include/optimizer/operator_node.h index f870df330eb..f450a2cea74 100644 --- a/src/include/optimizer/operator_node.h +++ b/src/include/optimizer/operator_node.h @@ -12,6 +12,7 @@ #pragma once +#include "optimizer/abstract_node.h" #include "optimizer/property_set.h" #include "util/hash_util.h" @@ -21,103 +22,25 @@ namespace peloton { namespace optimizer { -enum class OpType { - Undefined = 0, - // Special match operators - Leaf, - // Logical ops - Get, - LogicalExternalFileGet, - LogicalQueryDerivedGet, - LogicalProjection, - LogicalFilter, - LogicalMarkJoin, - LogicalDependentJoin, - LogicalSingleJoin, - InnerJoin, - LeftJoin, - RightJoin, - OuterJoin, - SemiJoin, - LogicalAggregateAndGroupBy, - LogicalInsert, - LogicalInsertSelect, - LogicalDelete, - LogicalUpdate, - LogicalLimit, - LogicalDistinct, - LogicalExportExternalFile, - // Separate between logical and physical ops - LogicalPhysicalDelimiter, - // Physical ops - DummyScan, /* Dummy Physical Op for SELECT without FROM*/ - SeqScan, - IndexScan, - ExternalFileScan, - QueryDerivedScan, - OrderBy, - PhysicalLimit, - Distinct, - InnerNLJoin, - LeftNLJoin, - RightNLJoin, - OuterNLJoin, - InnerHashJoin, - LeftHashJoin, - RightHashJoin, - OuterHashJoin, - Insert, - InsertSelect, - Delete, - Update, - Aggregate, - HashGroupBy, - SortGroupBy, - ExportExternalFile, -}; - //===--------------------------------------------------------------------===// // Operator Node //===--------------------------------------------------------------------===// class OperatorVisitor; -struct BaseOperatorNode { - BaseOperatorNode() {} - - virtual ~BaseOperatorNode() {} - - virtual void Accept(OperatorVisitor *v) const = 0; - - virtual std::string GetName() const = 0; - - virtual OpType GetType() const = 0; - - virtual bool IsLogical() const = 0; - - virtual bool IsPhysical() const = 0; - - virtual std::vector RequiredInputProperties() const { - return {}; - } - - virtual hash_t Hash() const { - OpType t = GetType(); - return HashUtil::Hash(&t); - } +// TODO: should probably use a new AbstractBaseNode interface, not AbstractNode +template +struct OperatorNode : public AbstractNode { + OperatorNode() : AbstractNode(nullptr) {} - virtual bool operator==(const BaseOperatorNode &r) { - return GetType() == r.GetType(); - } -}; + virtual ~OperatorNode() {} -// Curiously recurring template pattern -template -struct OperatorNode : public BaseOperatorNode { void Accept(OperatorVisitor *v) const; std::string GetName() const { return name_; } - OpType GetType() const { return type_; } + OpType GetOpType() const { return op_type_; } + + ExpressionType GetExpType() const { return exp_type_; } bool IsLogical() const; @@ -125,47 +48,34 @@ struct OperatorNode : public BaseOperatorNode { static std::string name_; - static OpType type_; + static OpType op_type_; + + static ExpressionType exp_type_; }; -class Operator { +class Operator : public AbstractNode { public: Operator(); - Operator(BaseOperatorNode *node); + Operator(std::shared_ptr node); - // Calls corresponding visitor to node void Accept(OperatorVisitor *v) const; - // Return name of operator std::string GetName() const; - // Return operator type - OpType GetType() const; + OpType GetOpType() const; + + ExpressionType GetExpType() const; - // Operator contains Logical node bool IsLogical() const; - // Operator contains Physical node bool IsPhysical() const; hash_t Hash() const; bool operator==(const Operator &r); - // Operator contains physical or logical operator node bool IsDefined() const; - - template - const T *As() const { - if (node && typeid(*node) == typeid(T)) { - return (const T *)node.get(); - } - return nullptr; - } - - private: - std::shared_ptr node; }; } // namespace optimizer @@ -174,8 +84,8 @@ class Operator { namespace std { template <> -struct hash { - typedef peloton::optimizer::BaseOperatorNode argument_type; +struct hash { + typedef peloton::optimizer::AbstractNode argument_type; typedef std::size_t result_type; result_type operator()(argument_type const &s) const { return s.Hash(); } }; diff --git a/src/include/optimizer/operators.h b/src/include/optimizer/operators.h index c8a8483d242..5b3b9f04847 100644 --- a/src/include/optimizer/operators.h +++ b/src/include/optimizer/operators.h @@ -38,9 +38,9 @@ class PropertySort; //===--------------------------------------------------------------------===// // Leaf //===--------------------------------------------------------------------===// -class LeafOperator : OperatorNode { +class LeafOperator : public OperatorNode { public: - static Operator make(GroupID group); + static std::shared_ptr make(GroupID group); GroupID origin_group; }; @@ -50,12 +50,12 @@ class LeafOperator : OperatorNode { //===--------------------------------------------------------------------===// class LogicalGet : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( oid_t get_id = 0, std::vector predicates = {}, std::shared_ptr table = nullptr, std::string alias = "", bool update = false); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -72,11 +72,11 @@ class LogicalGet : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalExternalFileGet : public OperatorNode { public: - static Operator make(oid_t get_id, ExternalFileFormat format, + static std::shared_ptr make(oid_t get_id, ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -94,13 +94,13 @@ class LogicalExternalFileGet : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalQueryDerivedGet : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( oid_t get_id, std::string &alias, std::unordered_map> alias_to_expr_map); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -117,10 +117,10 @@ class LogicalQueryDerivedGet : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalFilter : public OperatorNode { public: - static Operator make(std::vector &filter); + static std::shared_ptr make(std::vector &filter); std::vector predicates; - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; }; @@ -130,7 +130,7 @@ class LogicalFilter : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalProjection : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector> &elements); std::vector> expressions; }; @@ -140,11 +140,11 @@ class LogicalProjection : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalDependentJoin : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make(std::vector &conditions); + static std::shared_ptr make(std::vector &conditions); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -156,11 +156,11 @@ class LogicalDependentJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalMarkJoin : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make(std::vector &conditions); + static std::shared_ptr make(std::vector &conditions); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -172,11 +172,11 @@ class LogicalMarkJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalSingleJoin : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make(std::vector &conditions); + static std::shared_ptr make(std::vector &conditions); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -188,11 +188,11 @@ class LogicalSingleJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalInnerJoin : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make(std::vector &conditions); + static std::shared_ptr make(std::vector &conditions); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -204,7 +204,7 @@ class LogicalInnerJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalLeftJoin : public OperatorNode { public: - static Operator make(expression::AbstractExpression *condition = nullptr); + static std::shared_ptr make(expression::AbstractExpression *condition = nullptr); std::shared_ptr join_predicate; }; @@ -214,7 +214,7 @@ class LogicalLeftJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalRightJoin : public OperatorNode { public: - static Operator make(expression::AbstractExpression *condition = nullptr); + static std::shared_ptr make(expression::AbstractExpression *condition = nullptr); std::shared_ptr join_predicate; }; @@ -224,7 +224,7 @@ class LogicalRightJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalOuterJoin : public OperatorNode { public: - static Operator make(expression::AbstractExpression *condition = nullptr); + static std::shared_ptr make(expression::AbstractExpression *condition = nullptr); std::shared_ptr join_predicate; }; @@ -234,7 +234,7 @@ class LogicalOuterJoin : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalSemiJoin : public OperatorNode { public: - static Operator make(expression::AbstractExpression *condition = nullptr); + static std::shared_ptr make(expression::AbstractExpression *condition = nullptr); std::shared_ptr join_predicate; }; @@ -245,16 +245,16 @@ class LogicalSemiJoin : public OperatorNode { class LogicalAggregateAndGroupBy : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); - static Operator make( + static std::shared_ptr make( std::vector> &columns); - static Operator make( + static std::shared_ptr make( std::vector> &columns, std::vector &having); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; std::vector> columns; @@ -266,7 +266,7 @@ class LogicalAggregateAndGroupBy //===--------------------------------------------------------------------===// class LogicalInsert : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table, const std::vector *columns, const std::vector { class LogicalInsertSelect : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table); std::shared_ptr target_table; @@ -291,7 +291,7 @@ class LogicalInsertSelect : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalDistinct : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; //===--------------------------------------------------------------------===// @@ -299,7 +299,7 @@ class LogicalDistinct : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalLimit : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( int64_t offset, int64_t limit, std::vector &&sort_exprs, std::vector &&sort_ascending); @@ -318,7 +318,7 @@ class LogicalLimit : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalDelete : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table); std::shared_ptr target_table; @@ -329,7 +329,7 @@ class LogicalDelete : public OperatorNode { //===--------------------------------------------------------------------===// class LogicalUpdate : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table, const std::vector> *updates); @@ -343,10 +343,10 @@ class LogicalUpdate : public OperatorNode { class LogicalExportExternalFile : public OperatorNode { public: - static Operator make(ExternalFileFormat format, std::string file_name, + static std::shared_ptr make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -362,7 +362,7 @@ class LogicalExportExternalFile //===--------------------------------------------------------------------===// class DummyScan : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; //===--------------------------------------------------------------------===// @@ -370,13 +370,13 @@ class DummyScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalSeqScan : public OperatorNode { public: - static Operator make(oid_t get_id, + static std::shared_ptr make(oid_t get_id, std::shared_ptr table, std::string alias, std::vector predicates, bool update); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -393,7 +393,7 @@ class PhysicalSeqScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalIndexScan : public OperatorNode { public: - static Operator make(oid_t get_id, + static std::shared_ptr make(oid_t get_id, std::shared_ptr table, std::string alias, std::vector predicates, bool update, @@ -401,7 +401,7 @@ class PhysicalIndexScan : public OperatorNode { std::vector expr_type_list, std::vector value_list); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -426,11 +426,11 @@ class PhysicalIndexScan : public OperatorNode { //===--------------------------------------------------------------------===// class ExternalFileScan : public OperatorNode { public: - static Operator make(oid_t get_id, ExternalFileFormat format, + static std::shared_ptr make(oid_t get_id, ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -448,13 +448,13 @@ class ExternalFileScan : public OperatorNode { //===--------------------------------------------------------------------===// class QueryDerivedScan : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( oid_t get_id, std::string alias, std::unordered_map> alias_to_expr_map); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -471,7 +471,7 @@ class QueryDerivedScan : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalOrderBy : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; //===--------------------------------------------------------------------===// @@ -479,7 +479,7 @@ class PhysicalOrderBy : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalLimit : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( int64_t offset, int64_t limit, std::vector sort_columns, std::vector sort_ascending); @@ -498,12 +498,12 @@ class PhysicalLimit : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalInnerNLJoin : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -519,7 +519,7 @@ class PhysicalInnerNLJoin : public OperatorNode { class PhysicalLeftNLJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -529,7 +529,7 @@ class PhysicalLeftNLJoin : public OperatorNode { class PhysicalRightNLJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -539,7 +539,7 @@ class PhysicalRightNLJoin : public OperatorNode { class PhysicalOuterNLJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -548,12 +548,12 @@ class PhysicalOuterNLJoin : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalInnerHashJoin : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector conditions, std::vector> &left_keys, std::vector> &right_keys); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -569,7 +569,7 @@ class PhysicalInnerHashJoin : public OperatorNode { class PhysicalLeftHashJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -579,7 +579,7 @@ class PhysicalLeftHashJoin : public OperatorNode { class PhysicalRightHashJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -589,7 +589,7 @@ class PhysicalRightHashJoin : public OperatorNode { class PhysicalOuterHashJoin : public OperatorNode { public: std::shared_ptr join_predicate; - static Operator make( + static std::shared_ptr make( std::shared_ptr join_predicate); }; @@ -598,7 +598,7 @@ class PhysicalOuterHashJoin : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalInsert : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table, const std::vector *columns, const std::vector { class PhysicalInsertSelect : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table); std::shared_ptr target_table; @@ -623,7 +623,7 @@ class PhysicalInsertSelect : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalDelete : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table); std::shared_ptr target_table; }; @@ -633,7 +633,7 @@ class PhysicalDelete : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalUpdate : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::shared_ptr target_table, const std::vector> *updates); @@ -647,10 +647,10 @@ class PhysicalUpdate : public OperatorNode { class PhysicalExportExternalFile : public OperatorNode { public: - static Operator make(ExternalFileFormat format, std::string file_name, + static std::shared_ptr make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; @@ -666,11 +666,11 @@ class PhysicalExportExternalFile //===--------------------------------------------------------------------===// class PhysicalHashGroupBy : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector> columns, std::vector having); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; std::vector> columns; @@ -682,11 +682,11 @@ class PhysicalHashGroupBy : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalSortGroupBy : public OperatorNode { public: - static Operator make( + static std::shared_ptr make( std::vector> columns, std::vector having); - bool operator==(const BaseOperatorNode &r) override; + bool operator==(const AbstractNode &r) override; hash_t Hash() const override; // TODO(boweic): use raw ptr std::vector> columns; @@ -698,12 +698,12 @@ class PhysicalSortGroupBy : public OperatorNode { //===--------------------------------------------------------------------===// class PhysicalAggregate : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; class PhysicalDistinct : public OperatorNode { public: - static Operator make(); + static std::shared_ptr make(); }; } // namespace optimizer diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h index ebf82d625b4..93f1ddf1e76 100644 --- a/src/include/optimizer/optimizer.h +++ b/src/include/optimizer/optimizer.h @@ -89,7 +89,9 @@ class Optimizer : public AbstractOptimizer { /* For test purposes only */ std::shared_ptr TestInsertQueryTree( - parser::SQLStatement *tree, concurrency::TransactionContext *txn) { + parser::SQLStatement *tree, + concurrency::TransactionContext *txn) { + return InsertQueryTree(tree, txn); } /* For test purposes only */ diff --git a/src/include/optimizer/optimizer_metadata.h b/src/include/optimizer/optimizer_metadata.h index 3f33e3ee8b1..57dcb2ec7d8 100644 --- a/src/include/optimizer/optimizer_metadata.h +++ b/src/include/optimizer/optimizer_metadata.h @@ -19,6 +19,8 @@ #include "optimizer/rule.h" #include "settings/settings_manager.h" +#include + namespace peloton { namespace catalog { class Catalog; @@ -51,23 +53,22 @@ class OptimizerMetadata { } std::shared_ptr MakeGroupExpression( - std::shared_ptr expr) { + std::shared_ptr expr) { std::vector child_groups; for (auto &child : expr->Children()) { auto gexpr = MakeGroupExpression(child); memo.InsertExpression(gexpr, false); child_groups.push_back(gexpr->GetGroupID()); } - return std::make_shared(expr->Op(), - std::move(child_groups)); + return std::make_shared(expr->Node(), std::move(child_groups)); } - bool RecordTransformedExpression(std::shared_ptr expr, + bool RecordTransformedExpression(std::shared_ptr expr, std::shared_ptr &gexpr) { return RecordTransformedExpression(expr, gexpr, UNDEFINED_GROUP); } - bool RecordTransformedExpression(std::shared_ptr expr, + bool RecordTransformedExpression(std::shared_ptr expr, std::shared_ptr &gexpr, GroupID target_group) { gexpr = MakeGroupExpression(expr); @@ -75,7 +76,7 @@ class OptimizerMetadata { } // TODO(boweic): check if we really need to use shared_ptr - void ReplaceRewritedExpression(std::shared_ptr expr, + void ReplaceRewritedExpression(std::shared_ptr expr, GroupID target_group) { memo.EraseExpression(target_group); memo.InsertExpression(MakeGroupExpression(expr), target_group, false); diff --git a/src/include/optimizer/optimizer_task.h b/src/include/optimizer/optimizer_task.h index fb2edeaa5db..e02945e21b5 100644 --- a/src/include/optimizer/optimizer_task.h +++ b/src/include/optimizer/optimizer_task.h @@ -2,7 +2,7 @@ // // Peloton // -// rule.h +// optimizer_task.h // // Identification: src/include/optimizer/optimizer_task.h // @@ -14,6 +14,7 @@ #include #include +#include #include "expression/abstract_expression.h" #include "common/internal_types.h" @@ -32,6 +33,10 @@ class RuleSet; class Group; class GroupExpression; class OptimizerMetadata; + +enum class OpType; +class Operator; +class OperatorExpression; class PropertySet; enum class RewriteRuleSetName : uint32_t; using GroupID = int32_t; @@ -116,8 +121,7 @@ class OptimizeGroup : public OptimizerTask { */ class OptimizeExpression : public OptimizerTask { public: - OptimizeExpression(GroupExpression *group_expr, - std::shared_ptr context) + OptimizeExpression(GroupExpression *group_expr, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_EXPR), group_expr_(group_expr) {} virtual void execute() override; @@ -148,8 +152,7 @@ class ExploreGroup : public OptimizerTask { */ class ExploreExpression : public OptimizerTask { public: - ExploreExpression(GroupExpression *group_expr, - std::shared_ptr context) + ExploreExpression(GroupExpression *group_expr, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::EXPLORE_EXPR), group_expr_(group_expr) {} virtual void execute() override; @@ -189,8 +192,7 @@ class ApplyRule : public OptimizerTask { */ class OptimizeInputs : public OptimizerTask { public: - OptimizeInputs(GroupExpression *group_expr, - std::shared_ptr context) + OptimizeInputs(GroupExpression *group_expr, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::OPTIMIZE_INPUTS), group_expr_(group_expr) {} @@ -222,9 +224,7 @@ class OptimizeInputs : public OptimizerTask { */ class DeriveStats : public OptimizerTask { public: - DeriveStats(GroupExpression *gexpr, - ExprSet required_cols, - std::shared_ptr context) + DeriveStats(GroupExpression *gexpr, ExprSet required_cols, std::shared_ptr context) : OptimizerTask(context, OptimizerTaskType::DERIVE_STATS), gexpr_(gexpr), required_cols_(required_cols) {} @@ -241,24 +241,52 @@ class DeriveStats : public OptimizerTask { ExprSet required_cols_; }; + +/** + * @brief Higher abstraction above TopDownRewrite and BottomUpRewrite that + * implements functionality similar and relied upon by both TopDownRewrite + * and BottomUpRewrite. + */ +class RewriteTask : public OptimizerTask { + public: + RewriteTask(OptimizerTaskType type, GroupID group_id, + std::shared_ptr context, + RewriteRuleSetName rule_set_name) + : OptimizerTask(context, type), + group_id_(group_id), + rule_set_name_(rule_set_name) {} + + virtual void execute() override { + LOG_ERROR("RewriteTask::execute invoked directly and not on derived"); + PELOTON_ASSERT(0); + }; + + protected: + std::set GetUniqueChildGroupIDs(); + bool OptimizeCurrentGroup(bool replace_on_match); + + GroupID group_id_; + RewriteRuleSetName rule_set_name_; +}; + /** * @brief Apply top-down rewrite pass, take in a rule set which must fulfill * that the lower level rewrite in the operator tree will not enable upper * level rewrite. An example is predicate push-down. We only push the predicates * from the upper level to the lower level. */ -class TopDownRewrite : public OptimizerTask { +class TopDownRewrite : public RewriteTask { public: TopDownRewrite(GroupID group_id, std::shared_ptr context, RewriteRuleSetName rule_set_name) - : OptimizerTask(context, OptimizerTaskType::TOP_DOWN_REWRITE), - group_id_(group_id), - rule_set_name_(rule_set_name) {} + : RewriteTask(OptimizerTaskType::TOP_DOWN_REWRITE, group_id, context, rule_set_name), + replace_on_transform_(true) {} + + void SetReplaceOnTransform(bool replace) { replace_on_transform_ = replace; } virtual void execute() override; private: - GroupID group_id_; - RewriteRuleSetName rule_set_name_; + bool replace_on_transform_; }; /** @@ -266,19 +294,16 @@ class TopDownRewrite : public OptimizerTask { * that the upper level rewrite in the operator tree will not enable lower * level rewrite. */ -class BottomUpRewrite : public OptimizerTask { +class BottomUpRewrite : public RewriteTask { public: BottomUpRewrite(GroupID group_id, std::shared_ptr context, RewriteRuleSetName rule_set_name, bool has_optimized_child) - : OptimizerTask(context, OptimizerTaskType::BOTTOM_UP_REWRITE), - group_id_(group_id), - rule_set_name_(rule_set_name), + : RewriteTask(OptimizerTaskType::BOTTOM_UP_REWRITE, group_id, context, rule_set_name), has_optimized_child_(has_optimized_child) {} + virtual void execute() override; private: - GroupID group_id_; - RewriteRuleSetName rule_set_name_; bool has_optimized_child_; }; } // namespace optimizer diff --git a/src/include/optimizer/optimizer_task_pool.h b/src/include/optimizer/optimizer_task_pool.h index a14789df64a..4165b865ac5 100644 --- a/src/include/optimizer/optimizer_task_pool.h +++ b/src/include/optimizer/optimizer_task_pool.h @@ -24,6 +24,7 @@ namespace optimizer { * is identical to a stack but we may need to implement a different data * structure for multi-threaded optimization */ + class OptimizerTaskPool { public: virtual std::unique_ptr Pop() = 0; diff --git a/src/include/optimizer/pattern.h b/src/include/optimizer/pattern.h index 67c52592889..3db6eebeb6c 100644 --- a/src/include/optimizer/pattern.h +++ b/src/include/optimizer/pattern.h @@ -24,16 +24,21 @@ class Pattern { public: Pattern(OpType op); + Pattern(ExpressionType exp_type); + void AddChild(std::shared_ptr child); const std::vector> &Children() const; inline size_t GetChildPatternsSize() const { return children.size(); } - OpType Type() const; + OpType GetOpType() const; + + ExpressionType GetExpType() const; private: - OpType _type; + OpType _op_type = OpType::Undefined; + ExpressionType _exp_type = ExpressionType::INVALID; std::vector> children; }; diff --git a/src/include/optimizer/rewriter.h b/src/include/optimizer/rewriter.h new file mode 100644 index 00000000000..271f6611032 --- /dev/null +++ b/src/include/optimizer/rewriter.h @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rewriter.h +// +// Identification: src/include/optimizer/rewriter.h +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include "expression/abstract_expression.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/optimizer_task_pool.h" +#include "optimizer/absexpr_expression.h" + +namespace peloton { +namespace optimizer { + +class Rewriter { + + public: + /** + * Default constructor + */ + Rewriter(); + + /** + * Resets the internal state of the rewriter + */ + void Reset(); + + DISALLOW_COPY_AND_MOVE(Rewriter); + + /** + * Gets the OptimizerMetadata used by the rewriter + * @returns internal OptimizerMetadata + */ + OptimizerMetadata &GetMetadata() { return metadata_; } + + /** + * Rewrites an expression by applying applicable rules + * @param expr AbstractExpression to rewrite + * @returns rewriteen AbstractExpression + */ + expression::AbstractExpression* RewriteExpression(const expression::AbstractExpression *expr); + + private: + /** + * Creates an AbstractExpression from the Memo used internally + * @param root_group GroupID of the root group to begin building from + * @returns AbstractExpression from the stored groups + */ + expression::AbstractExpression* RebuildExpression(int root_group); + + /** + * Performs a single rewrite pass on the epxression + * @param root_group_id GroupID of the group to start rewriting from + */ + void RewriteLoop(int root_group_id); + + /** + * Converts AbstractExpression into internal rewriter representation + * @param expr expression to convert + * @returns shared pointer to rewriter internal representation + */ + std::shared_ptr ConvertToAbsExpr(const expression::AbstractExpression *expr); + + /** + * Records the original groups (subtrees) of the AbstractExpression. + * From the recorded information, it is possible to rebuild the expression. + * @param expr expression whose groups to record + * @returns GroupExpression representing the root of the expression + */ + std::shared_ptr RecordTreeGroups(const expression::AbstractExpression *expr); + + /** + * OptimizerMetadata that we leverage + */ + OptimizerMetadata metadata_; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/rule.h b/src/include/optimizer/rule.h index 4ea78a630c6..eb1711342ed 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -26,9 +26,6 @@ class GroupExpression; #define PHYS_PROMISE 3 #define LOG_PROMISE 1 -/** - * @brief The base class of all rules - */ class Rule { public: virtual ~Rule(){}; @@ -74,7 +71,7 @@ class Rule { * * @return If the rule is applicable, return true, otherwise return false */ - virtual bool Check(std::shared_ptr expr, + virtual bool Check(std::shared_ptr expr, OptimizeContext *context) const = 0; /** @@ -85,8 +82,8 @@ class Rule { * @param context The current optimization context */ virtual void Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const = 0; inline RuleType GetType() { return type_; } @@ -113,7 +110,9 @@ struct RuleWithPromise { enum class RewriteRuleSetName : uint32_t { PREDICATE_PUSH_DOWN = 0, - UNNEST_SUBQUERY + UNNEST_SUBQUERY, + EQUIVALENT_TRANSFORM, + GENERIC_RULES }; /** @@ -146,9 +145,13 @@ class RuleSet { return rewrite_rules_map_[static_cast(set)]; } - std::unordered_map>> &GetRewriteRulesMap() { return rewrite_rules_map_; } + std::unordered_map>> &GetRewriteRulesMap() { + return rewrite_rules_map_; + } - std::vector> &GetPredicatePushDownRules() { return predicate_push_down_rules_; } + std::vector> &GetPredicatePushDownRules() { + return predicate_push_down_rules_; + } private: std::vector> transformation_rules_; diff --git a/src/include/optimizer/rule_impls.h b/src/include/optimizer/rule_impls.h index 57902e744a9..9296f2ce3a9 100644 --- a/src/include/optimizer/rule_impls.h +++ b/src/include/optimizer/rule_impls.h @@ -30,11 +30,11 @@ class InnerJoinCommutativity : public Rule { public: InnerJoinCommutativity(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -46,11 +46,11 @@ class InnerJoinAssociativity : public Rule { public: InnerJoinAssociativity(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -65,11 +65,11 @@ class GetToSeqScan : public Rule { public: GetToSeqScan(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -77,11 +77,11 @@ class LogicalExternalFileGetToPhysical : public Rule { public: LogicalExternalFileGetToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -93,11 +93,11 @@ class GetToDummyScan : public Rule { public: GetToDummyScan(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -108,11 +108,11 @@ class GetToIndexScan : public Rule { public: GetToIndexScan(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -123,11 +123,11 @@ class LogicalQueryDerivedGetToPhysical : public Rule { public: LogicalQueryDerivedGetToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -138,11 +138,11 @@ class LogicalDeleteToPhysical : public Rule { public: LogicalDeleteToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -153,11 +153,11 @@ class LogicalUpdateToPhysical : public Rule { public: LogicalUpdateToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -168,11 +168,11 @@ class LogicalInsertToPhysical : public Rule { public: LogicalInsertToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -183,11 +183,11 @@ class LogicalInsertSelectToPhysical : public Rule { public: LogicalInsertSelectToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -198,11 +198,11 @@ class LogicalGroupByToHashGroupBy : public Rule { public: LogicalGroupByToHashGroupBy(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -213,11 +213,11 @@ class LogicalAggregateToPhysical : public Rule { public: LogicalAggregateToPhysical(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -228,11 +228,11 @@ class InnerJoinToInnerNLJoin : public Rule { public: InnerJoinToInnerNLJoin(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -243,11 +243,11 @@ class InnerJoinToInnerHashJoin : public Rule { public: InnerJoinToInnerHashJoin(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -258,11 +258,11 @@ class ImplementDistinct : public Rule { public: ImplementDistinct(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -273,11 +273,11 @@ class ImplementLimit : public Rule { public: ImplementLimit(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -288,11 +288,11 @@ class LogicalExportToPhysicalExport : public Rule { public: LogicalExportToPhysicalExport(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -310,11 +310,11 @@ class PushFilterThroughJoin : public Rule { public: PushFilterThroughJoin(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -325,11 +325,11 @@ class CombineConsecutiveFilter : public Rule { public: CombineConsecutiveFilter(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -341,11 +341,11 @@ class PushFilterThroughAggregation : public Rule { public: PushFilterThroughAggregation(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; /** @@ -357,11 +357,11 @@ class EmbedFilterIntoGet : public Rule { public: EmbedFilterIntoGet(); - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -384,11 +384,11 @@ class MarkJoinToInnerJoin : public Rule { int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; /////////////////////////////////////////////////////////////////////////////// @@ -400,11 +400,11 @@ class SingleJoinToInnerJoin : public Rule { int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -417,11 +417,11 @@ class PullFilterThroughMarkJoin : public Rule { int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; @@ -434,11 +434,11 @@ class PullFilterThroughAggregation : public Rule { int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; - bool Check(std::shared_ptr plan, + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; - void Transform(std::shared_ptr input, - std::vector> &transformed, + void Transform(std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const override; }; } // namespace optimizer diff --git a/src/include/optimizer/rule_rewrite.h b/src/include/optimizer/rule_rewrite.h new file mode 100644 index 00000000000..0720fd0c875 --- /dev/null +++ b/src/include/optimizer/rule_rewrite.h @@ -0,0 +1,168 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rule_rewrite.h +// +// Identification: src/include/optimizer/rule_rewrite.h +// +// Copyright (c) 2015-16, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "optimizer/rule.h" +#include "optimizer/absexpr_expression.h" + +#include + +namespace peloton { +namespace optimizer { + +using GroupExprTemplate = GroupExpression; +using OptimizeContext = OptimizeContext; + +/* Rules are applied from high to low priority */ +#define HIGH_PRIORITY 3 +#define MEDIUM_PRIORITY 2 +#define LOW_PRIORITY 1 + +/* + * Comparator Elimination: When two constant values are compared against + * each other (==, !=, >, <, >=, <=), the comparison expression gets rewritten + * to either TRUE or FALSE, depending on whether the constants agree with the + * comparison or not + * Examples: + * "1 == 2" ==> "FALSE" + * "3 <= 4" ==> "TRUE" + */ +class ComparatorElimination: public Rule { + public: + ComparatorElimination(RuleType rule, ExpressionType root); + + int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +/* + * Equivalent Transform: When a symmetric operator (==, !=, AND, OR) has two + * children, the comparison expression gets its arguments flipped. + * Examples: + * "T.X != 3" ==> "3 != T.X" + * "(T.X == 1) AND (T.Y == 2)" ==> "(T.Y == 2) AND (T.X == 1)" + */ +class EquivalentTransform: public Rule { + public: + EquivalentTransform(RuleType rule, ExpressionType root); + + int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +/* + * Tuple Value Equality with Two Constant Values: When the same tuple reference + * is checked against two distinct constant values, the expression is rewritten + * to FALSE + * Example: + * "(T.X == 3) AND (T.X == 4)" ==> "FALSE" + */ +class TVEqualityWithTwoCVTransform: public Rule { + public: + TVEqualityWithTwoCVTransform(); + + int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +/* + * Transitive Closure w/ Constants: When two tuple references are compared against each + * other and one of the tuple references is compared to a constant, the expression + * swaps out the doubled tuple reference for the constant. + * Example: + * "(T.X == Q.Y) AND (T.X == 6)" ==> "(6 == Q.Y) AND (T.X == 6)" + */ +class TransitiveClosureConstantTransform: public Rule { + public: + TransitiveClosureConstantTransform(); + + int Promise(GroupExpression *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +/* + * And Short Circuiting: Anything AND FALSE is rewritten to FALSE. + */ +class AndShortCircuit: public Rule { + public: + AndShortCircuit(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +/* + * Or Short Circuiting: Anything OR TRUE is rewritten to TRUE. + */ +class OrShortCircuit: public Rule { + public: + OrShortCircuit(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +/* + * Null Lookup on Not Null Column: Asking if a tuple reference is NULL is rewritten + * to FALSE only when the catalog says that that attribute has a non-NULL constraint. + * Example: + * "T.X IS NULL" ==> "FALSE" (assuming T.X is a non-NULL attribute) + */ +class NullLookupOnNotNullColumn: public Rule { + public: + NullLookupOnNotNullColumn(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +/* + * Not Null Lookup on Not Null Column: Asking if a tuple reference is NOT NULL is rewritten + * to TRUE only when the catalog says that that attribute has a non-NULL constraint. + * Example: + * "T.X IS NOT NULL" ==> "TRUE" (assuming T.X is a non-NULL attribute) + */ +class NotNullLookupOnNotNullColumn: public Rule { + public: + NotNullLookupOnNotNullColumn(); + + int Promise(GroupExprTemplate *group_expr, OptimizeContext *context) const override; + bool Check(std::shared_ptr plan, OptimizeContext *context) const override; + void Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const override; +}; + +} // namespace optimizer +} // namespace peloton diff --git a/src/include/optimizer/stats/child_stats_deriver.h b/src/include/optimizer/stats/child_stats_deriver.h index d0c72f9bf9b..f4f3c05be20 100644 --- a/src/include/optimizer/stats/child_stats_deriver.h +++ b/src/include/optimizer/stats/child_stats_deriver.h @@ -23,13 +23,16 @@ namespace optimizer { class Memo; +class OperatorExpression; + // Derive child stats that has not yet been calculated for a logical group // expression class ChildStatsDeriver : public OperatorVisitor { public: std::vector DeriveInputStats( GroupExpression *gexpr, - ExprSet required_cols, Memo *memo); + ExprSet required_cols, + Memo *memo); void Visit(const LogicalQueryDerivedGet *) override; void Visit(const LogicalInnerJoin *) override; diff --git a/src/include/optimizer/stats/stats_calculator.h b/src/include/optimizer/stats/stats_calculator.h index 9637db2f224..79d1988de95 100644 --- a/src/include/optimizer/stats/stats_calculator.h +++ b/src/include/optimizer/stats/stats_calculator.h @@ -19,6 +19,7 @@ namespace optimizer { class Memo; class TableStats; +class OperatorExpression; /** * @brief Derive stats for the root group using a group expression's children's @@ -26,8 +27,10 @@ class TableStats; */ class StatsCalculator : public OperatorVisitor { public: - void CalculateStats(GroupExpression *gexpr, ExprSet required_cols, - Memo *memo, concurrency::TransactionContext* txn); + void CalculateStats(GroupExpression *gexpr, + ExprSet required_cols, + Memo *memo, + concurrency::TransactionContext* txn); void Visit(const LogicalGet *) override; void Visit(const LogicalQueryDerivedGet *) override; diff --git a/src/include/traffic_cop/traffic_cop.h b/src/include/traffic_cop/traffic_cop.h index e324b87fe82..8870591a4df 100644 --- a/src/include/traffic_cop/traffic_cop.h +++ b/src/include/traffic_cop/traffic_cop.h @@ -25,6 +25,7 @@ #include "common/statement.h" #include "executor/plan_executor.h" #include "optimizer/abstract_optimizer.h" +#include "optimizer/rewriter.h" #include "parser/sql_statement.h" #include "type/type.h" @@ -196,6 +197,8 @@ class TrafficCop { // still a HACK void GetTableColumns(parser::TableRef *from_table, std::vector &target_tables); + + optimizer::Rewriter rewriter_; }; } // namespace tcop diff --git a/src/optimizer/absexpr_expression.cpp b/src/optimizer/absexpr_expression.cpp new file mode 100644 index 00000000000..30122603ffe --- /dev/null +++ b/src/optimizer/absexpr_expression.cpp @@ -0,0 +1,146 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// absexpr_expression.cpp +// +// Identification: src/optimizer/absexpr_expression.cpp +// +//===----------------------------------------------------------------------===// + +#include "optimizer/absexpr_expression.h" +#include "expression/operator_expression.h" +#include "expression/aggregate_expression.h" +#include "expression/star_expression.h" + +#include +#include + +namespace peloton { +namespace optimizer { + +expression::AbstractExpression *AbsExprNode::CopyWithChildren(std::vector children) { + // Pre-compute left and right + expression::AbstractExpression *left = nullptr; + expression::AbstractExpression *right = nullptr; + if (children.size() >= 2) { + left = children[0]; + right = children[1]; + } else if (children.size() == 1) { + left = children[0]; + } + + auto type = GetExpType(); + switch (type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_NOTEQUAL: + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + case ExpressionType::COMPARE_LIKE: + case ExpressionType::COMPARE_NOTLIKE: + case ExpressionType::COMPARE_IN: + case ExpressionType::COMPARE_DISTINCT_FROM: { + // Create new expression with 2 new children of the same type + return new expression::ComparisonExpression(type, left, right); + } + + case ExpressionType::CONJUNCTION_AND: + case ExpressionType::CONJUNCTION_OR: { + // Create new expression with the new children + return new expression::ConjunctionExpression(type, left, right); + } + + case ExpressionType::OPERATOR_PLUS: + case ExpressionType::OPERATOR_MINUS: + case ExpressionType::OPERATOR_MULTIPLY: + case ExpressionType::OPERATOR_DIVIDE: + case ExpressionType::OPERATOR_CONCAT: + case ExpressionType::OPERATOR_MOD: + case ExpressionType::OPERATOR_NOT: + case ExpressionType::OPERATOR_IS_NULL: + case ExpressionType::OPERATOR_IS_NOT_NULL: + case ExpressionType::OPERATOR_EXISTS: { + // Create new expression, preserving return_value_type_ + type::TypeId ret = expr->GetValueType(); + return new expression::OperatorExpression(type, ret, left, right); + } + + case ExpressionType::OPERATOR_UNARY_MINUS: { + PELOTON_ASSERT(children.size() == 1); + return new expression::OperatorUnaryMinusExpression(left); + } + + case ExpressionType::STAR: + case ExpressionType::VALUE_CONSTANT: + case ExpressionType::VALUE_PARAMETER: + case ExpressionType::VALUE_TUPLE: { + PELOTON_ASSERT(children.size() == 0); + return expr->Copy(); + } + + case ExpressionType::AGGREGATE_COUNT: + case ExpressionType::AGGREGATE_COUNT_STAR: + case ExpressionType::AGGREGATE_SUM: + case ExpressionType::AGGREGATE_MIN: + case ExpressionType::AGGREGATE_MAX: + case ExpressionType::AGGREGATE_AVG: { + // We should not be changing # of children of AggregateExpression + PELOTON_ASSERT(expr->GetChildrenSize() == children.size()); + + // Unfortunately, the aggregate_expression (also applies to function) + // may already have extra state information created due to the binder. + // Under Peloton's design, we decide to just copy() the node and then + // install the child. + auto expr_copy = expr->Copy(); + if (children.size() == 1) { + // If we updated the child, install the child + expr_copy->SetChild(0, children[0]); + } + + return expr_copy; + } + + case ExpressionType::FUNCTION: { + // We really should not be modifying # children of Function + PELOTON_ASSERT(children.size() == expr->GetChildrenSize()); + auto copy = expr->Copy(); + + size_t num_child = children.size(); + for (size_t i = 0; i < num_child; i++) { + copy->SetChild(i, children[i]); + } + return copy; + } + + // Rewriting for these 2 uses special matching patterns. + // As such, when building as an output, we just copy directly. + case ExpressionType::ROW_SUBQUERY: + case ExpressionType::OPERATOR_CASE_EXPR: { + PELOTON_ASSERT(children.size() == 0); + return expr->Copy(); + } + + // These ExpressionTypes are never instantiated as a type + case ExpressionType::PLACEHOLDER: + case ExpressionType::COLUMN_REF: + case ExpressionType::FUNCTION_REF: + case ExpressionType::TABLE_REF: + case ExpressionType::SELECT_SUBQUERY: + case ExpressionType::VALUE_TUPLE_ADDRESS: + case ExpressionType::VALUE_NULL: + case ExpressionType::VALUE_VECTOR: + case ExpressionType::VALUE_SCALAR: + case ExpressionType::HASH_RANGE: + case ExpressionType::OPERATOR_CAST: + default: { + int type = static_cast(GetExpType()); + LOG_ERROR("Unimplemented Rebuild() for %d found", type); + return expr->Copy(); + } + } +} + +} // namespace optimizer +} // namespace peloton diff --git a/src/optimizer/binding.cpp b/src/optimizer/binding.cpp index 9651ce8102c..40786696326 100644 --- a/src/optimizer/binding.cpp +++ b/src/optimizer/binding.cpp @@ -12,9 +12,15 @@ #include "optimizer/binding.h" +#include + #include "common/logger.h" #include "optimizer/operator_visitor.h" #include "optimizer/optimizer.h" +#include "optimizer/absexpr_expression.h" +#include "expression/group_marker_expression.h" +#include "expression/abstract_expression.h" +#include "expression/tuple_value_expression.h" namespace peloton { namespace optimizer { @@ -27,15 +33,18 @@ GroupBindingIterator::GroupBindingIterator(Memo &memo, GroupID id, : BindingIterator(memo), group_id_(id), pattern_(pattern), - target_group_(memo_.GetGroupByID(id)), + target_group_(this->memo_.GetGroupByID(id)), num_group_items_(target_group_->GetLogicalExpressions().size()), current_item_index_(0) { LOG_TRACE("Attempting to bind on group %d", id); } bool GroupBindingIterator::HasNext() { - LOG_TRACE("HasNext"); - if (pattern_->Type() == OpType::Leaf) { + LOG_TRACE("HasNextBinding"); + + // TODO(): Can we do this generic pattern any better? + if ((pattern_->GetOpType() == OpType::Leaf && pattern_->GetExpType() == ExpressionType::INVALID) || + (pattern_->GetOpType() == OpType::Undefined && pattern_->GetExpType() == ExpressionType::GROUP_MARKER)) { return current_item_index_ == 0; } @@ -51,7 +60,7 @@ bool GroupBindingIterator::HasNext() { // Keep checking item iterators until we find a match while (current_item_index_ < num_group_items_) { current_iterator_.reset(new GroupExprBindingIterator( - memo_, + this->memo_, target_group_->GetLogicalExpressions()[current_item_index_].get(), pattern_)); @@ -67,11 +76,19 @@ bool GroupBindingIterator::HasNext() { return current_iterator_ != nullptr; } -std::shared_ptr GroupBindingIterator::Next() { - if (pattern_->Type() == OpType::Leaf) { +std::shared_ptr GroupBindingIterator::Next() { + if (pattern_->GetOpType() == OpType::Leaf && pattern_->GetExpType() == ExpressionType::INVALID) { current_item_index_ = num_group_items_; return std::make_shared(LeafOperator::make(group_id_)); } + + if (pattern_->GetOpType() == OpType::Undefined && pattern_->GetExpType() == ExpressionType::GROUP_MARKER) { + current_item_index_ = num_group_items_; + + auto expr = std::make_shared(group_id_); + return std::make_shared(std::make_shared(expr)); + } + return current_iterator_->Next(); } @@ -79,20 +96,27 @@ std::shared_ptr GroupBindingIterator::Next() { // Item Binding Iterator //===--------------------------------------------------------------------===// GroupExprBindingIterator::GroupExprBindingIterator( - Memo &memo, GroupExpression *gexpr, std::shared_ptr pattern) + Memo &memo, GroupExpression *gexpr, std::shared_ptr pattern) : BindingIterator(memo), gexpr_(gexpr), pattern_(pattern), first_(true), - has_next_(false), - current_binding_(std::make_shared(gexpr->Op())) { - if (gexpr->Op().GetType() != pattern->Type()) { + has_next_(false) { + + if (gexpr->Node()->GetOpType() != pattern->GetOpType() || + gexpr->Node()->GetExpType() != pattern->GetExpType()) { return; } + // Create right type of AbstractNodeExpression depending on type + if (gexpr->Node()->GetOpType() != OpType::Undefined) { + current_binding_ = std::make_shared(gexpr->Node()); + } else { + current_binding_ = std::make_shared(gexpr->Node()); + } + const std::vector &child_groups = gexpr->GetChildGroupIDs(); - const std::vector> &child_patterns = - pattern->Children(); + const std::vector> &child_patterns = pattern->Children(); if (child_groups.size() != child_patterns.size()) { return; @@ -100,16 +124,16 @@ GroupExprBindingIterator::GroupExprBindingIterator( LOG_TRACE( "Attempting to bind on group %d with expression of type %s, children " "size %lu", - gexpr->GetGroupID(), gexpr->Op().GetName().c_str(), child_groups.size()); + gexpr->GetGroupID(), gexpr->Node()->GetName().c_str(), child_groups.size()); // Find all bindings for children children_bindings_.resize(child_groups.size(), {}); children_bindings_pos_.resize(child_groups.size(), 0); for (size_t i = 0; i < child_groups.size(); ++i) { // Try to find a match in the given group - std::vector> &child_bindings = + std::vector> &child_bindings = children_bindings_[i]; - GroupBindingIterator iterator(memo_, child_groups[i], child_patterns[i]); + GroupBindingIterator iterator(this->memo_, child_groups[i], child_patterns[i]); // Get all bindings while (iterator.HasNext()) { @@ -137,7 +161,7 @@ bool GroupExprBindingIterator::HasNext() { // The first child to be modified int first_modified_idx = children_bindings_pos_.size() - 1; for (; first_modified_idx >= 0; --first_modified_idx) { - const std::vector> &child_binding = + const std::vector> &child_binding = children_bindings_[first_modified_idx]; // Try to increment idx from the back @@ -154,16 +178,15 @@ bool GroupExprBindingIterator::HasNext() { has_next_ = false; } else { // Pop all updated childrens - for (size_t idx = first_modified_idx; idx < children_bindings_pos_.size(); - idx++) { + for (size_t idx = first_modified_idx; idx < children_bindings_pos_.size(); idx++) { current_binding_->PopChild(); } // Add new children to end for (size_t offset = first_modified_idx; offset < children_bindings_pos_.size(); ++offset) { - const std::vector> &child_binding = + const std::vector> &child_binding = children_bindings_[offset]; - std::shared_ptr binding = + std::shared_ptr binding = child_binding[children_bindings_pos_[offset]]; current_binding_->PushChild(binding); } @@ -172,9 +195,10 @@ bool GroupExprBindingIterator::HasNext() { return has_next_; } -std::shared_ptr GroupExprBindingIterator::Next() { +std::shared_ptr GroupExprBindingIterator::Next() { return current_binding_; } + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index b432067fae1..c41562e07a7 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -38,7 +38,7 @@ ChildPropertyDeriver::GetProperties(GroupExpression *gexpr, output_.clear(); memo_ = memo; gexpr_ = gexpr; - gexpr->Op().Accept(this); + gexpr->Node()->Accept(this); return move(output_); } diff --git a/src/optimizer/group.cpp b/src/optimizer/group.cpp index 673a7a1b8bd..55f6093be75 100644 --- a/src/optimizer/group.cpp +++ b/src/optimizer/group.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "optimizer/group.h" +#include "optimizer/operator_expression.h" +#include "optimizer/absexpr_expression.h" #include "common/logger.h" @@ -27,11 +29,18 @@ Group::Group(GroupID id, std::unordered_set table_aliases) void Group::AddExpression(std::shared_ptr expr, bool enforced) { + + // Additional assertion checks for AddExpression() with AST rewriting + if (expr->Node()->GetExpType() != ExpressionType::INVALID) { + PELOTON_ASSERT(!enforced); + PELOTON_ASSERT(!expr->Node()->IsPhysical()); + } + // Do duplicate detection expr->SetGroupID(id_); if (enforced) enforced_exprs_.push_back(expr); - else if (expr->Op().IsPhysical()) + else if (expr->Node()->IsPhysical()) physical_expressions_.push_back(expr); else logical_expressions_.push_back(expr); @@ -39,8 +48,9 @@ void Group::AddExpression(std::shared_ptr expr, bool Group::SetExpressionCost(GroupExpression *expr, double cost, std::shared_ptr &properties) { + LOG_TRACE("Adding expression cost on group %d with op %s, req %s", - expr->GetGroupID(), expr->Op().GetName().c_str(), + expr->GetGroupID(), expr->Node()->GetName().c_str(), properties->ToString().c_str()); auto it = lowest_cost_expressions_.find(properties); if (it == lowest_cost_expressions_.end() || std::get<0>(it->second) > cost) { @@ -51,8 +61,10 @@ bool Group::SetExpressionCost(GroupExpression *expr, double cost, } return false; } + GroupExpression *Group::GetBestExpression( std::shared_ptr &properties) { + auto it = lowest_cost_expressions_.find(properties); if (it != lowest_cost_expressions_.end()) { return std::get<1>(it->second); @@ -64,6 +76,7 @@ GroupExpression *Group::GetBestExpression( bool Group::HasExpressions( const std::shared_ptr &properties) const { + const auto &it = lowest_cost_expressions_.find(properties); return (it != lowest_cost_expressions_.end()); } @@ -76,68 +89,68 @@ std::shared_ptr Group::GetStats(std::string column_name) { } const std::string Group::GetInfo(int num_indent) const { - std::ostringstream os; - os << StringUtil::Indent(num_indent) - << "GroupID: " << GetID() << std::endl; - - if (logical_expressions_.size() > 0) - os << StringUtil::Indent(num_indent + 2) - << "logical_expressions_: \n"; - - for (auto expr : logical_expressions_) { - os << StringUtil::Indent(num_indent + 4) - << expr->Op().GetName() << std::endl; - const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); - if (ChildGroupIDs.size() > 0) { - os << StringUtil::Indent(num_indent + 6) - << "ChildGroupIDs: "; - for (auto childGroupID : ChildGroupIDs) - os << childGroupID << " "; - os << std::endl; - } - } - - if (physical_expressions_.size() > 0) - os << StringUtil::Indent(num_indent + 2) - << "physical_expressions_: \n"; - for (auto expr : physical_expressions_) { - os << StringUtil::Indent(num_indent + 4) - << expr->Op().GetName() << std::endl; - const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); - if (ChildGroupIDs.size() > 0) { - os << StringUtil::Indent(num_indent + 6) - << "ChildGroupIDs: "; - for (auto childGroupID : ChildGroupIDs) - os << childGroupID << " "; - os << std::endl; - } - - } - - if (enforced_exprs_.size() > 0) - os << StringUtil::Indent(num_indent + 2) - << "enforced_exprs_: \n"; - for (auto expr : enforced_exprs_) { - os << StringUtil::Indent(num_indent + 4) - << expr->Op().GetName() << std::endl; - const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); - if (ChildGroupIDs.size() > 0) { - os << StringUtil::Indent(num_indent + 6) - << "ChildGroupIDs: \n"; - for (auto childGroupID : ChildGroupIDs) { - os << childGroupID << " "; - } - os << std::endl; - } - } - - return os.str(); + std::ostringstream os; + os << StringUtil::Indent(num_indent) + << "GroupID: " << GetID() << std::endl; + + if (logical_expressions_.size() > 0) + os << StringUtil::Indent(num_indent + 2) + << "logical_expressions_: \n"; + + for (auto expr : logical_expressions_) { + os << StringUtil::Indent(num_indent + 4) + << expr->Node()->GetName() << std::endl; + const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); + if (ChildGroupIDs.size() > 0) { + os << StringUtil::Indent(num_indent + 6) + << "ChildGroupIDs: "; + for (auto childGroupID : ChildGroupIDs) + os << childGroupID << " "; + os << std::endl; + } + } + + if (physical_expressions_.size() > 0) + os << StringUtil::Indent(num_indent + 2) + << "physical_expressions_: \n"; + for (auto expr : physical_expressions_) { + os << StringUtil::Indent(num_indent + 4) + << expr->Node()->GetName() << std::endl; + const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); + if (ChildGroupIDs.size() > 0) { + os << StringUtil::Indent(num_indent + 6) + << "ChildGroupIDs: "; + for (auto childGroupID : ChildGroupIDs) + os << childGroupID << " "; + os << std::endl; + } + + } + + if (enforced_exprs_.size() > 0) + os << StringUtil::Indent(num_indent + 2) + << "enforced_exprs_: \n"; + for (auto expr : enforced_exprs_) { + os << StringUtil::Indent(num_indent + 4) + << expr->Node()->GetName() << std::endl; + const std::vector ChildGroupIDs = expr->GetChildGroupIDs(); + if (ChildGroupIDs.size() > 0) { + os << StringUtil::Indent(num_indent + 6) + << "ChildGroupIDs: \n"; + for (auto childGroupID : ChildGroupIDs) { + os << childGroupID << " "; + } + os << std::endl; + } + } + + return os.str(); } const std::string Group::GetInfo() const { - std::ostringstream os; - os << GetInfo(0); - return os.str(); + std::ostringstream os; + os << GetInfo(0); + return os.str(); } diff --git a/src/optimizer/group_expression.cpp b/src/optimizer/group_expression.cpp index 498c949b583..4b11242c30a 100644 --- a/src/optimizer/group_expression.cpp +++ b/src/optimizer/group_expression.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "common/internal_types.h" +#include "optimizer/absexpr_expression.h" #include "optimizer/group_expression.h" #include "optimizer/group.h" #include "optimizer/rule.h" @@ -21,9 +22,10 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Group Expression //===--------------------------------------------------------------------===// -GroupExpression::GroupExpression(Operator op, std::vector child_groups) +GroupExpression::GroupExpression(std::shared_ptr node, + std::vector child_groups) : group_id(UNDEFINED_GROUP), - op(op), + node(node), child_groups(child_groups), stats_derived_(false) {} @@ -43,7 +45,9 @@ GroupID GroupExpression::GetChildGroupId(int child_idx) const { return child_groups[child_idx]; } -Operator GroupExpression::Op() const { return op; } +std::shared_ptr GroupExpression::Node() const { + return std::shared_ptr(node); +} double GroupExpression::GetCost( std::shared_ptr &requirements) const { @@ -74,7 +78,7 @@ void GroupExpression::SetLocalHashTable( } hash_t GroupExpression::Hash() const { - size_t hash = op.Hash(); + size_t hash = node->Hash(); for (size_t i = 0; i < child_groups.size(); ++i) { hash = HashUtil::CombineHashes(hash, @@ -85,7 +89,7 @@ hash_t GroupExpression::Hash() const { } bool GroupExpression::operator==(const GroupExpression &r) { - return (op == r.Op()) && (child_groups == r.child_groups); + return (*node == *r.Node()) && (child_groups == r.child_groups); } void GroupExpression::SetRuleExplored(Rule *rule) { diff --git a/src/optimizer/input_column_deriver.cpp b/src/optimizer/input_column_deriver.cpp index fdffb7e79a6..c2f94cc2d3d 100644 --- a/src/optimizer/input_column_deriver.cpp +++ b/src/optimizer/input_column_deriver.cpp @@ -43,7 +43,7 @@ InputColumnDeriver::DeriveInputColumns( gexpr_ = gexpr; required_cols_ = move(required_cols); memo_ = memo; - gexpr->Op().Accept(this); + gexpr->Node()->Accept(this); return move(output_input_cols_); } @@ -200,7 +200,7 @@ void InputColumnDeriver::ScanHelper() { output_cols, {}}; } -void InputColumnDeriver::AggregateHelper(const BaseOperatorNode *op) { +void InputColumnDeriver::AggregateHelper(const AbstractNode *op) { ExprSet input_cols_set; ExprMap output_cols_map; oid_t output_col_idx = 0; @@ -232,11 +232,11 @@ void InputColumnDeriver::AggregateHelper(const BaseOperatorNode *op) { // TODO(boweic): do not use shared_ptr vector> groupby_cols; vector having_exprs; - if (op->GetType() == OpType::HashGroupBy) { + if (op->GetOpType() == OpType::HashGroupBy) { auto groupby = reinterpret_cast(op); groupby_cols = groupby->columns; having_exprs = groupby->having; - } else if (op->GetType() == OpType::SortGroupBy) { + } else if (op->GetOpType() == OpType::SortGroupBy) { auto groupby = reinterpret_cast(op); groupby_cols = groupby->columns; having_exprs = groupby->having; @@ -269,17 +269,17 @@ void InputColumnDeriver::AggregateHelper(const BaseOperatorNode *op) { output_cols, {input_cols}}; } -void InputColumnDeriver::JoinHelper(const BaseOperatorNode *op) { +void InputColumnDeriver::JoinHelper(const AbstractNode *op) { const vector *join_conds = nullptr; const vector> *left_keys = nullptr; const vector> *right_keys = nullptr; - if (op->GetType() == OpType::InnerHashJoin) { + if (op->GetOpType() == OpType::InnerHashJoin) { auto join_op = reinterpret_cast(op); join_conds = &(join_op->join_predicates); left_keys = &(join_op->left_keys); right_keys = &(join_op->right_keys); - } else if (op->GetType() == OpType::InnerNLJoin) { + } else if (op->GetOpType() == OpType::InnerNLJoin) { auto join_op = reinterpret_cast(op); join_conds = &(join_op->join_predicates); left_keys = &(join_op->left_keys); diff --git a/src/optimizer/memo.cpp b/src/optimizer/memo.cpp index ca68a52c1d0..9f649cc00ae 100644 --- a/src/optimizer/memo.cpp +++ b/src/optimizer/memo.cpp @@ -14,6 +14,8 @@ #include "optimizer/memo.h" #include "optimizer/operators.h" #include "optimizer/stats/stats_calculator.h" +#include "optimizer/absexpr_expression.h" +#include "expression/group_marker_expression.h" namespace peloton { namespace optimizer { @@ -23,25 +25,36 @@ namespace optimizer { //===--------------------------------------------------------------------===// Memo::Memo() {} -GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, - bool enforced) { - return InsertExpression(gexpr, UNDEFINED_GROUP, enforced); -} - +//===--------------------------------------------------------------------===// +// Memo remaining interface functions +//===--------------------------------------------------------------------===// GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, GroupID target_group, bool enforced) { + // If leaf, then just return - if (gexpr->Op().GetType() == OpType::Leaf) { - const LeafOperator *leaf = gexpr->Op().As(); + if (gexpr->Node()->GetOpType() == OpType::Leaf && + gexpr->Node()->GetExpType() == ExpressionType::INVALID) { + const LeafOperator *leaf = gexpr->Node()->As(); PELOTON_ASSERT(target_group == UNDEFINED_GROUP || target_group == leaf->origin_group); gexpr->SetGroupID(leaf->origin_group); return nullptr; } - // Lookup in hash table - auto it = group_expressions_.find(gexpr.get()); + if (gexpr->Node()->GetOpType() == OpType::Undefined && + gexpr->Node()->GetExpType() == ExpressionType::GROUP_MARKER) { + + auto abs_node = std::dynamic_pointer_cast(gexpr->Node()); + PELOTON_ASSERT(abs_node != nullptr); + + auto gm_expr = std::dynamic_pointer_cast(abs_node->GetExpr()); + PELOTON_ASSERT(gm_expr != nullptr); + PELOTON_ASSERT(target_group == UNDEFINED_GROUP || target_group == gm_expr->GetGroupID()); + gexpr->SetGroupID(gm_expr->GetGroupID()); + return nullptr; + } + auto it = group_expressions_.find(gexpr.get()); if (it != group_expressions_.end()) { gexpr->SetGroupID((*it)->GetGroupID()); return *it; @@ -55,12 +68,19 @@ GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, } else { group_id = target_group; } + Group *group = GetGroupByID(group_id); group->AddExpression(gexpr, enforced); return gexpr.get(); } } +GroupExpression *Memo::InsertExpression(std::shared_ptr gexpr, + bool enforced) { + + return InsertExpression(gexpr, UNDEFINED_GROUP, enforced); +} + std::vector> &Memo::Groups() { return groups_; } @@ -86,21 +106,24 @@ const std::string Memo::GetInfo() const { return os.str(); } - +//===--------------------------------------------------------------------===// +// Memo::AddNewGroup +//===--------------------------------------------------------------------===// GroupID Memo::AddNewGroup(std::shared_ptr gexpr) { GroupID new_group_id = groups_.size(); // Find out the table alias that this group represents std::unordered_set table_aliases; - auto op_type = gexpr->Op().GetType(); + auto op_type = gexpr->Node()->GetOpType(); + if (op_type == OpType::Get) { // For base group, the table alias can get directly from logical get - const LogicalGet *logical_get = gexpr->Op().As(); + const LogicalGet *logical_get = gexpr->Node()->As(); table_aliases.insert(logical_get->table_alias); } else if (op_type == OpType::LogicalQueryDerivedGet) { const LogicalQueryDerivedGet *query_get = - gexpr->Op().As(); + gexpr->Node()->As(); table_aliases.insert(query_get->table_alias); - } else { + } else if (op_type != OpType::Undefined) { // For other groups, need to aggregate the table alias from children for (auto child_group_id : gexpr->GetChildGroupIDs()) { Group *child_group = GetGroupByID(child_group_id); @@ -110,8 +133,7 @@ GroupID Memo::AddNewGroup(std::shared_ptr gexpr) { } } - groups_.emplace_back( - new Group(new_group_id, std::move(table_aliases))); + groups_.emplace_back(new Group(new_group_id, std::move(table_aliases))); return new_group_id; } diff --git a/src/optimizer/operator_expression.cpp b/src/optimizer/operator_expression.cpp index 9ba9d2de706..ba96a124e27 100644 --- a/src/optimizer/operator_expression.cpp +++ b/src/optimizer/operator_expression.cpp @@ -2,11 +2,11 @@ // // Peloton // -// op_expression.cpp +// operator_expression.cpp // -// Identification: src/optimizer/op_expression.cpp +// Identification: src/optimizer/operator_expression.cpp // -// Copyright (c) 2015-16, Carnegie Mellon University Database Group +// Copyright (c) 2015-19, Carnegie Mellon University Database Group // //===----------------------------------------------------------------------===// @@ -20,26 +20,26 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Operator Expression //===--------------------------------------------------------------------===// -OperatorExpression::OperatorExpression(Operator op) : op(op) {} +OperatorExpression::OperatorExpression(std::shared_ptr node) : node(node) {} -void OperatorExpression::PushChild(std::shared_ptr op) { - children.push_back(op); +void OperatorExpression::PushChild(std::shared_ptr node) { + children.push_back(node); } void OperatorExpression::PopChild() { children.pop_back(); } -const std::vector> +const std::vector> &OperatorExpression::Children() const { return children; } -const Operator &OperatorExpression::Op() const { return op; } +const std::shared_ptr OperatorExpression::Node() const { return node; } const std::string OperatorExpression::GetInfo() const { std::string info = "{"; { info += "\"Op\":"; - info += "\"" + op.GetName() + "\","; + info += "\"" + node->GetName() + "\","; if (!children.empty()) { info += "\"Children\":["; { diff --git a/src/optimizer/operator_node.cpp b/src/optimizer/operator_node.cpp index e262792e774..cb80ab5bf39 100644 --- a/src/optimizer/operator_node.cpp +++ b/src/optimizer/operator_node.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "optimizer/abstract_node.h" #include "optimizer/operator_node.h" namespace peloton { @@ -18,9 +19,9 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Operator //===--------------------------------------------------------------------===// -Operator::Operator() : node(nullptr) {} +Operator::Operator() : AbstractNode(nullptr) {} -Operator::Operator(BaseOperatorNode *node) : node(node) {} +Operator::Operator(std::shared_ptr node) : AbstractNode(node) {} void Operator::Accept(OperatorVisitor *v) const { node->Accept(v); } @@ -31,13 +32,20 @@ std::string Operator::GetName() const { return "Undefined"; } -OpType Operator::GetType() const { +OpType Operator::GetOpType() const { if (IsDefined()) { - return node->GetType(); + return node->GetOpType(); } return OpType::Undefined; } +ExpressionType Operator::GetExpType() const { + if (IsDefined()) { + return node->GetExpType(); + } + return ExpressionType::INVALID; +} + bool Operator::IsLogical() const { if (IsDefined()) { return node->IsLogical(); diff --git a/src/optimizer/operators.cpp b/src/optimizer/operators.cpp index 52cf83f9a8c..17038d1b1eb 100644 --- a/src/optimizer/operators.cpp +++ b/src/optimizer/operators.cpp @@ -21,16 +21,16 @@ namespace optimizer { //===--------------------------------------------------------------------===// // Leaf //===--------------------------------------------------------------------===// -Operator LeafOperator::make(GroupID group) { +std::shared_ptr LeafOperator::make(GroupID group) { LeafOperator *op = new LeafOperator; op->origin_group = group; - return Operator(op); + return std::make_shared(std::shared_ptr(op)); } //===--------------------------------------------------------------------===// // Get //===--------------------------------------------------------------------===// -Operator LogicalGet::make(oid_t get_id, +std::shared_ptr LogicalGet::make(oid_t get_id, std::vector predicates, std::shared_ptr table, std::string alias, bool update) { @@ -41,19 +41,19 @@ Operator LogicalGet::make(oid_t get_id, get->is_for_update = update; get->get_id = get_id; util::to_lower_string(get->table_alias); - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } hash_t LogicalGet::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); for (auto &pred : predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalGet::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::Get) return false; +bool LogicalGet::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::Get) return false; const LogicalGet &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; for (size_t i = 0; i < predicates.size(); i++) { @@ -67,9 +67,9 @@ bool LogicalGet::operator==(const BaseOperatorNode &r) { // External file get //===--------------------------------------------------------------------===// -Operator LogicalExternalFileGet::make(oid_t get_id, ExternalFileFormat format, - std::string file_name, char delimiter, - char quote, char escape) { +std::shared_ptr LogicalExternalFileGet::make( + oid_t get_id, ExternalFileFormat format, std::string file_name, char delimiter, + char quote, char escape) { auto *get = new LogicalExternalFileGet(); get->get_id = get_id; get->format = format; @@ -77,11 +77,11 @@ Operator LogicalExternalFileGet::make(oid_t get_id, ExternalFileFormat format, get->delimiter = delimiter; get->quote = quote; get->escape = escape; - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } -bool LogicalExternalFileGet::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::LogicalExternalFileGet) return false; +bool LogicalExternalFileGet::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::LogicalExternalFileGet) return false; const auto &get = *static_cast(&node); return (get_id == get.get_id && format == get.format && file_name == get.file_name && delimiter == get.delimiter && @@ -89,7 +89,7 @@ bool LogicalExternalFileGet::operator==(const BaseOperatorNode &node) { } hash_t LogicalExternalFileGet::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&format)); hash = HashUtil::CombineHashes( @@ -103,7 +103,7 @@ hash_t LogicalExternalFileGet::Hash() const { //===--------------------------------------------------------------------===// // Query derived get //===--------------------------------------------------------------------===// -Operator LogicalQueryDerivedGet::make( +std::shared_ptr LogicalQueryDerivedGet::make( oid_t get_id, std::string &alias, std::unordered_map> @@ -113,18 +113,18 @@ Operator LogicalQueryDerivedGet::make( get->alias_to_expr_map = alias_to_expr_map; get->get_id = get_id; - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } -bool LogicalQueryDerivedGet::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::LogicalQueryDerivedGet) return false; +bool LogicalQueryDerivedGet::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::LogicalQueryDerivedGet) return false; const LogicalQueryDerivedGet &r = *static_cast(&node); return get_id == r.get_id; } hash_t LogicalQueryDerivedGet::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); return hash; } @@ -132,21 +132,21 @@ hash_t LogicalQueryDerivedGet::Hash() const { //===--------------------------------------------------------------------===// // Select //===--------------------------------------------------------------------===// -Operator LogicalFilter::make(std::vector &filter) { +std::shared_ptr LogicalFilter::make(std::vector &filter) { LogicalFilter *select = new LogicalFilter; select->predicates = std::move(filter); - return Operator(select); + return std::make_shared(std::shared_ptr(select)); } hash_t LogicalFilter::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalFilter::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::LogicalFilter) return false; +bool LogicalFilter::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::LogicalFilter) return false; const LogicalFilter &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; for (size_t i = 0; i < predicates.size(); i++) { @@ -158,38 +158,38 @@ bool LogicalFilter::operator==(const BaseOperatorNode &r) { //===--------------------------------------------------------------------===// // Project //===--------------------------------------------------------------------===// -Operator LogicalProjection::make( +std::shared_ptr LogicalProjection::make( std::vector> &elements) { LogicalProjection *projection = new LogicalProjection; projection->expressions = std::move(elements); - return Operator(projection); + return std::make_shared(std::shared_ptr(projection)); } //===--------------------------------------------------------------------===// // DependentJoin //===--------------------------------------------------------------------===// -Operator LogicalDependentJoin::make() { +std::shared_ptr LogicalDependentJoin::make() { LogicalDependentJoin *join = new LogicalDependentJoin; join->join_predicates = {}; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } -Operator LogicalDependentJoin::make( +std::shared_ptr LogicalDependentJoin::make( std::vector &conditions) { LogicalDependentJoin *join = new LogicalDependentJoin; join->join_predicates = std::move(conditions); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t LogicalDependentJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalDependentJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::LogicalDependentJoin) return false; +bool LogicalDependentJoin::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::LogicalDependentJoin) return false; const LogicalDependentJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; @@ -204,27 +204,27 @@ bool LogicalDependentJoin::operator==(const BaseOperatorNode &r) { //===--------------------------------------------------------------------===// // MarkJoin //===--------------------------------------------------------------------===// -Operator LogicalMarkJoin::make() { +std::shared_ptr LogicalMarkJoin::make() { LogicalMarkJoin *join = new LogicalMarkJoin; join->join_predicates = {}; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } -Operator LogicalMarkJoin::make(std::vector &conditions) { +std::shared_ptr LogicalMarkJoin::make(std::vector &conditions) { LogicalMarkJoin *join = new LogicalMarkJoin; join->join_predicates = std::move(conditions); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t LogicalMarkJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalMarkJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::LogicalMarkJoin) return false; +bool LogicalMarkJoin::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::LogicalMarkJoin) return false; const LogicalMarkJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; for (size_t i = 0; i < join_predicates.size(); i++) { @@ -238,27 +238,27 @@ bool LogicalMarkJoin::operator==(const BaseOperatorNode &r) { //===--------------------------------------------------------------------===// // SingleJoin //===--------------------------------------------------------------------===// -Operator LogicalSingleJoin::make() { +std::shared_ptr LogicalSingleJoin::make() { LogicalMarkJoin *join = new LogicalMarkJoin; join->join_predicates = {}; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } -Operator LogicalSingleJoin::make(std::vector &conditions) { +std::shared_ptr LogicalSingleJoin::make(std::vector &conditions) { LogicalSingleJoin *join = new LogicalSingleJoin; join->join_predicates = std::move(conditions); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t LogicalSingleJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalSingleJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::LogicalSingleJoin) return false; +bool LogicalSingleJoin::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::LogicalSingleJoin) return false; const LogicalSingleJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; for (size_t i = 0; i < join_predicates.size(); i++) { @@ -272,27 +272,27 @@ bool LogicalSingleJoin::operator==(const BaseOperatorNode &r) { //===--------------------------------------------------------------------===// // InnerJoin //===--------------------------------------------------------------------===// -Operator LogicalInnerJoin::make() { +std::shared_ptr LogicalInnerJoin::make() { LogicalInnerJoin *join = new LogicalInnerJoin; join->join_predicates = {}; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } -Operator LogicalInnerJoin::make(std::vector &conditions) { +std::shared_ptr LogicalInnerJoin::make(std::vector &conditions) { LogicalInnerJoin *join = new LogicalInnerJoin; join->join_predicates = std::move(conditions); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t LogicalInnerJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : join_predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); return hash; } -bool LogicalInnerJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerJoin) return false; +bool LogicalInnerJoin::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::InnerJoin) return false; const LogicalInnerJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size()) return false; for (size_t i = 0; i < join_predicates.size(); i++) { @@ -306,70 +306,70 @@ bool LogicalInnerJoin::operator==(const BaseOperatorNode &r) { //===--------------------------------------------------------------------===// // LeftJoin //===--------------------------------------------------------------------===// -Operator LogicalLeftJoin::make(expression::AbstractExpression *condition) { +std::shared_ptr LogicalLeftJoin::make(expression::AbstractExpression *condition) { LogicalLeftJoin *join = new LogicalLeftJoin; join->join_predicate = std::shared_ptr(condition); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // RightJoin //===--------------------------------------------------------------------===// -Operator LogicalRightJoin::make(expression::AbstractExpression *condition) { +std::shared_ptr LogicalRightJoin::make(expression::AbstractExpression *condition) { LogicalRightJoin *join = new LogicalRightJoin; join->join_predicate = std::shared_ptr(condition); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // OuterJoin //===--------------------------------------------------------------------===// -Operator LogicalOuterJoin::make(expression::AbstractExpression *condition) { +std::shared_ptr LogicalOuterJoin::make(expression::AbstractExpression *condition) { LogicalOuterJoin *join = new LogicalOuterJoin; join->join_predicate = std::shared_ptr(condition); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // OuterJoin //===--------------------------------------------------------------------===// -Operator LogicalSemiJoin::make(expression::AbstractExpression *condition) { +std::shared_ptr LogicalSemiJoin::make(expression::AbstractExpression *condition) { LogicalSemiJoin *join = new LogicalSemiJoin; join->join_predicate = std::shared_ptr(condition); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // Aggregate //===--------------------------------------------------------------------===// -Operator LogicalAggregateAndGroupBy::make() { +std::shared_ptr LogicalAggregateAndGroupBy::make() { LogicalAggregateAndGroupBy *group_by = new LogicalAggregateAndGroupBy; group_by->columns = {}; - return Operator(group_by); + return std::make_shared(std::shared_ptr(group_by)); } -Operator LogicalAggregateAndGroupBy::make( +std::shared_ptr LogicalAggregateAndGroupBy::make( std::vector> &columns) { LogicalAggregateAndGroupBy *group_by = new LogicalAggregateAndGroupBy; group_by->columns = move(columns); - return Operator(group_by); + return std::make_shared(std::shared_ptr(group_by)); } -Operator LogicalAggregateAndGroupBy::make( +std::shared_ptr LogicalAggregateAndGroupBy::make( std::vector> &columns, std::vector &having) { LogicalAggregateAndGroupBy *group_by = new LogicalAggregateAndGroupBy; group_by->columns = move(columns); group_by->having = move(having); - return Operator(group_by); + return std::make_shared(std::shared_ptr(group_by)); } -bool LogicalAggregateAndGroupBy::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::LogicalAggregateAndGroupBy) return false; +bool LogicalAggregateAndGroupBy::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::LogicalAggregateAndGroupBy) return false; const LogicalAggregateAndGroupBy &r = *static_cast(&node); if (having.size() != r.having.size() || columns.size() != r.columns.size()) @@ -381,7 +381,7 @@ bool LogicalAggregateAndGroupBy::operator==(const BaseOperatorNode &node) { } hash_t LogicalAggregateAndGroupBy::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : having) hash = HashUtil::SumHashes(hash, pred.expr->Hash()); for (auto expr : columns) hash = HashUtil::SumHashes(hash, expr->Hash()); return hash; @@ -390,7 +390,7 @@ hash_t LogicalAggregateAndGroupBy::Hash() const { //===--------------------------------------------------------------------===// // Insert //===--------------------------------------------------------------------===// -Operator LogicalInsert::make( +std::shared_ptr LogicalInsert::make( std::shared_ptr target_table, const std::vector *columns, const std::vectortarget_table = target_table; insert_op->columns = columns; insert_op->values = values; - return Operator(insert_op); + return std::make_shared(std::shared_ptr(insert_op)); } -Operator LogicalInsertSelect::make( +std::shared_ptr LogicalInsertSelect::make( std::shared_ptr target_table) { LogicalInsertSelect *insert_op = new LogicalInsertSelect; insert_op->target_table = target_table; - return Operator(insert_op); + return std::make_shared(std::shared_ptr(insert_op)); } //===--------------------------------------------------------------------===// // Delete //===--------------------------------------------------------------------===// -Operator LogicalDelete::make( +std::shared_ptr LogicalDelete::make( std::shared_ptr target_table) { LogicalDelete *delete_op = new LogicalDelete; delete_op->target_table = target_table; - return Operator(delete_op); + return std::make_shared(std::shared_ptr(delete_op)); } //===--------------------------------------------------------------------===// // Update //===--------------------------------------------------------------------===// -Operator LogicalUpdate::make( +std::shared_ptr LogicalUpdate::make( std::shared_ptr target_table, const std::vector> * updates) { LogicalUpdate *update_op = new LogicalUpdate; update_op->target_table = target_table; update_op->updates = updates; - return Operator(update_op); + return std::make_shared(std::shared_ptr(update_op)); } //===--------------------------------------------------------------------===// // Distinct //===--------------------------------------------------------------------===// -Operator LogicalDistinct::make() { - LogicalDistinct *distinct = new LogicalDistinct; - return Operator(distinct); +std::shared_ptr LogicalDistinct::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// // Limit //===--------------------------------------------------------------------===// -Operator LogicalLimit::make( +std::shared_ptr LogicalLimit::make( int64_t offset, int64_t limit, std::vector &&sort_exprs, std::vector &&sort_ascending) { @@ -452,13 +451,13 @@ Operator LogicalLimit::make( limit_op->limit = limit; limit_op->sort_exprs = std::move(sort_exprs); limit_op->sort_ascending = std::move(sort_ascending); - return Operator(limit_op); + return std::make_shared(std::shared_ptr(limit_op)); } //===--------------------------------------------------------------------===// // External file output //===--------------------------------------------------------------------===// -Operator LogicalExportExternalFile::make(ExternalFileFormat format, +std::shared_ptr LogicalExportExternalFile::make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape) { auto *export_op = new LogicalExportExternalFile(); @@ -467,11 +466,11 @@ Operator LogicalExportExternalFile::make(ExternalFileFormat format, export_op->delimiter = delimiter; export_op->quote = quote; export_op->escape = escape; - return Operator(export_op); + return std::make_shared(std::shared_ptr(export_op)); } -bool LogicalExportExternalFile::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::LogicalExportExternalFile) return false; +bool LogicalExportExternalFile::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::LogicalExportExternalFile) return false; const auto &export_op = *static_cast(&node); return (format == export_op.format && file_name == export_op.file_name && @@ -480,7 +479,7 @@ bool LogicalExportExternalFile::operator==(const BaseOperatorNode &node) { } hash_t LogicalExportExternalFile::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&format)); hash = HashUtil::CombineHashes( hash, HashUtil::HashBytes(file_name.data(), file_name.length())); @@ -493,15 +492,14 @@ hash_t LogicalExportExternalFile::Hash() const { //===--------------------------------------------------------------------===// // DummyScan //===--------------------------------------------------------------------===// -Operator DummyScan::make() { - DummyScan *dummy = new DummyScan; - return Operator(dummy); +std::shared_ptr DummyScan::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// // SeqScan //===--------------------------------------------------------------------===// -Operator PhysicalSeqScan::make( +std::shared_ptr PhysicalSeqScan::make( oid_t get_id, std::shared_ptr table, std::string alias, std::vector predicates, bool update) { @@ -513,11 +511,11 @@ Operator PhysicalSeqScan::make( scan->is_for_update = update; scan->get_id = get_id; - return Operator(scan); + return std::make_shared(std::shared_ptr(scan)); } -bool PhysicalSeqScan::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::SeqScan) return false; +bool PhysicalSeqScan::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::SeqScan) return false; const PhysicalSeqScan &node = *static_cast(&r); if (predicates.size() != node.predicates.size()) return false; for (size_t i = 0; i < predicates.size(); i++) { @@ -528,7 +526,7 @@ bool PhysicalSeqScan::operator==(const BaseOperatorNode &r) { } hash_t PhysicalSeqScan::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); for (auto &pred : predicates) hash = HashUtil::CombineHashes(hash, pred.expr->Hash()); @@ -538,7 +536,7 @@ hash_t PhysicalSeqScan::Hash() const { //===--------------------------------------------------------------------===// // IndexScan //===--------------------------------------------------------------------===// -Operator PhysicalIndexScan::make( +std::shared_ptr PhysicalIndexScan::make( oid_t get_id, std::shared_ptr table, std::string alias, std::vector predicates, bool update, oid_t index_id, std::vector key_column_id_list, @@ -556,11 +554,11 @@ Operator PhysicalIndexScan::make( scan->expr_type_list = std::move(expr_type_list); scan->value_list = std::move(value_list); - return Operator(scan); + return std::make_shared(std::shared_ptr(scan)); } -bool PhysicalIndexScan::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::IndexScan) return false; +bool PhysicalIndexScan::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::IndexScan) return false; const PhysicalIndexScan &node = *static_cast(&r); // TODO: Should also check value list if (index_id != node.index_id || @@ -577,7 +575,7 @@ bool PhysicalIndexScan::operator==(const BaseOperatorNode &r) { } hash_t PhysicalIndexScan::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&index_id)); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); for (auto &pred : predicates) @@ -588,7 +586,7 @@ hash_t PhysicalIndexScan::Hash() const { //===--------------------------------------------------------------------===// // Physical external file scan //===--------------------------------------------------------------------===// -Operator ExternalFileScan::make(oid_t get_id, ExternalFileFormat format, +std::shared_ptr ExternalFileScan::make(oid_t get_id, ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape) { auto *get = new ExternalFileScan(); @@ -598,11 +596,11 @@ Operator ExternalFileScan::make(oid_t get_id, ExternalFileFormat format, get->delimiter = delimiter; get->quote = quote; get->escape = escape; - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } -bool ExternalFileScan::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::QueryDerivedScan) return false; +bool ExternalFileScan::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::QueryDerivedScan) return false; const auto &get = *static_cast(&node); return (get_id == get.get_id && format == get.format && file_name == get.file_name && delimiter == get.delimiter && @@ -610,7 +608,7 @@ bool ExternalFileScan::operator==(const BaseOperatorNode &node) { } hash_t ExternalFileScan::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&format)); hash = HashUtil::CombineHashes( @@ -624,7 +622,7 @@ hash_t ExternalFileScan::Hash() const { //===--------------------------------------------------------------------===// // Query derived get //===--------------------------------------------------------------------===// -Operator QueryDerivedScan::make( +std::shared_ptr QueryDerivedScan::make( oid_t get_id, std::string alias, std::unordered_map> @@ -634,17 +632,17 @@ Operator QueryDerivedScan::make( get->alias_to_expr_map = alias_to_expr_map; get->get_id = get_id; - return Operator(get); + return std::make_shared(std::shared_ptr(get)); } -bool QueryDerivedScan::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::QueryDerivedScan) return false; +bool QueryDerivedScan::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::QueryDerivedScan) return false; const QueryDerivedScan &r = *static_cast(&node); return get_id == r.get_id; } hash_t QueryDerivedScan::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&get_id)); return hash; } @@ -652,16 +650,14 @@ hash_t QueryDerivedScan::Hash() const { //===--------------------------------------------------------------------===// // OrderBy //===--------------------------------------------------------------------===// -Operator PhysicalOrderBy::make() { - PhysicalOrderBy *order_by = new PhysicalOrderBy; - - return Operator(order_by); +std::shared_ptr PhysicalOrderBy::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// // PhysicalLimit //===--------------------------------------------------------------------===// -Operator PhysicalLimit::make( +std::shared_ptr PhysicalLimit::make( int64_t offset, int64_t limit, std::vector sort_exprs, std::vector sort_ascending) { @@ -670,13 +666,13 @@ Operator PhysicalLimit::make( limit_op->limit = limit; limit_op->sort_exprs = sort_exprs; limit_op->sort_acsending = sort_ascending; - return Operator(limit_op); + return std::make_shared(std::shared_ptr(limit_op)); } //===--------------------------------------------------------------------===// // InnerNLJoin //===--------------------------------------------------------------------===// -Operator PhysicalInnerNLJoin::make( +std::shared_ptr PhysicalInnerNLJoin::make( std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { @@ -685,11 +681,11 @@ Operator PhysicalInnerNLJoin::make( join->left_keys = std::move(left_keys); join->right_keys = std::move(right_keys); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t PhysicalInnerNLJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &expr : left_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); for (auto &expr : right_keys) @@ -699,8 +695,8 @@ hash_t PhysicalInnerNLJoin::Hash() const { return hash; } -bool PhysicalInnerNLJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerNLJoin) return false; +bool PhysicalInnerNLJoin::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::InnerNLJoin) return false; const PhysicalInnerNLJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size() || @@ -724,37 +720,37 @@ bool PhysicalInnerNLJoin::operator==(const BaseOperatorNode &r) { //===--------------------------------------------------------------------===// // LeftNLJoin //===--------------------------------------------------------------------===// -Operator PhysicalLeftNLJoin::make( +std::shared_ptr PhysicalLeftNLJoin::make( std::shared_ptr join_predicate) { PhysicalLeftNLJoin *join = new PhysicalLeftNLJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // RightNLJoin //===--------------------------------------------------------------------===// -Operator PhysicalRightNLJoin::make( +std::shared_ptr PhysicalRightNLJoin::make( std::shared_ptr join_predicate) { PhysicalRightNLJoin *join = new PhysicalRightNLJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // OuterNLJoin //===--------------------------------------------------------------------===// -Operator PhysicalOuterNLJoin::make( +std::shared_ptr PhysicalOuterNLJoin::make( std::shared_ptr join_predicate) { PhysicalOuterNLJoin *join = new PhysicalOuterNLJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // InnerHashJoin //===--------------------------------------------------------------------===// -Operator PhysicalInnerHashJoin::make( +std::shared_ptr PhysicalInnerHashJoin::make( std::vector conditions, std::vector> &left_keys, std::vector> &right_keys) { @@ -762,11 +758,11 @@ Operator PhysicalInnerHashJoin::make( join->join_predicates = std::move(conditions); join->left_keys = std::move(left_keys); join->right_keys = std::move(right_keys); - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } hash_t PhysicalInnerHashJoin::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &expr : left_keys) hash = HashUtil::CombineHashes(hash, expr->Hash()); for (auto &expr : right_keys) @@ -776,8 +772,8 @@ hash_t PhysicalInnerHashJoin::Hash() const { return hash; } -bool PhysicalInnerHashJoin::operator==(const BaseOperatorNode &r) { - if (r.GetType() != OpType::InnerHashJoin) return false; +bool PhysicalInnerHashJoin::operator==(const AbstractNode &r) { + if (r.GetOpType() != OpType::InnerHashJoin) return false; const PhysicalInnerHashJoin &node = *static_cast(&r); if (join_predicates.size() != node.join_predicates.size() || @@ -801,37 +797,37 @@ bool PhysicalInnerHashJoin::operator==(const BaseOperatorNode &r) { //===--------------------------------------------------------------------===// // LeftHashJoin //===--------------------------------------------------------------------===// -Operator PhysicalLeftHashJoin::make( +std::shared_ptr PhysicalLeftHashJoin::make( std::shared_ptr join_predicate) { PhysicalLeftHashJoin *join = new PhysicalLeftHashJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // RightHashJoin //===--------------------------------------------------------------------===// -Operator PhysicalRightHashJoin::make( +std::shared_ptr PhysicalRightHashJoin::make( std::shared_ptr join_predicate) { PhysicalRightHashJoin *join = new PhysicalRightHashJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // OuterHashJoin //===--------------------------------------------------------------------===// -Operator PhysicalOuterHashJoin::make( +std::shared_ptr PhysicalOuterHashJoin::make( std::shared_ptr join_predicate) { PhysicalOuterHashJoin *join = new PhysicalOuterHashJoin(); join->join_predicate = join_predicate; - return Operator(join); + return std::make_shared(std::shared_ptr(join)); } //===--------------------------------------------------------------------===// // PhysicalInsert //===--------------------------------------------------------------------===// -Operator PhysicalInsert::make( +std::shared_ptr PhysicalInsert::make( std::shared_ptr target_table, const std::vector *columns, const std::vectortarget_table = target_table; insert_op->columns = columns; insert_op->values = values; - return Operator(insert_op); + return std::make_shared(std::shared_ptr(insert_op)); } //===--------------------------------------------------------------------===// // PhysicalInsertSelect //===--------------------------------------------------------------------===// -Operator PhysicalInsertSelect::make( +std::shared_ptr PhysicalInsertSelect::make( std::shared_ptr target_table) { PhysicalInsertSelect *insert_op = new PhysicalInsertSelect; insert_op->target_table = target_table; - return Operator(insert_op); + return std::make_shared(std::shared_ptr(insert_op)); } //===--------------------------------------------------------------------===// // PhysicalDelete //===--------------------------------------------------------------------===// -Operator PhysicalDelete::make( +std::shared_ptr PhysicalDelete::make( std::shared_ptr target_table) { PhysicalDelete *delete_op = new PhysicalDelete; delete_op->target_table = target_table; - return Operator(delete_op); + return std::make_shared(std::shared_ptr(delete_op)); } //===--------------------------------------------------------------------===// // PhysicalUpdate //===--------------------------------------------------------------------===// -Operator PhysicalUpdate::make( +std::shared_ptr PhysicalUpdate::make( std::shared_ptr target_table, const std::vector> * updates) { PhysicalUpdate *update = new PhysicalUpdate; update->target_table = target_table; update->updates = updates; - return Operator(update); + return std::make_shared(std::shared_ptr(update)); } //===--------------------------------------------------------------------===// // PhysicalExportExternalFile //===--------------------------------------------------------------------===// -Operator PhysicalExportExternalFile::make(ExternalFileFormat format, +std::shared_ptr PhysicalExportExternalFile::make(ExternalFileFormat format, std::string file_name, char delimiter, char quote, char escape) { auto *export_op = new PhysicalExportExternalFile(); @@ -888,11 +884,11 @@ Operator PhysicalExportExternalFile::make(ExternalFileFormat format, export_op->delimiter = delimiter; export_op->quote = quote; export_op->escape = escape; - return Operator(export_op); + return std::make_shared(std::shared_ptr(export_op)); } -bool PhysicalExportExternalFile::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::ExportExternalFile) return false; +bool PhysicalExportExternalFile::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::ExportExternalFile) return false; const auto &export_op = *static_cast(&node); return (format == export_op.format && file_name == export_op.file_name && @@ -901,7 +897,7 @@ bool PhysicalExportExternalFile::operator==(const BaseOperatorNode &node) { } hash_t PhysicalExportExternalFile::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); hash = HashUtil::CombineHashes(hash, HashUtil::Hash(&format)); hash = HashUtil::CombineHashes( hash, HashUtil::HashBytes(file_name.data(), file_name.length())); @@ -914,17 +910,17 @@ hash_t PhysicalExportExternalFile::Hash() const { //===--------------------------------------------------------------------===// // PhysicalHashGroupBy //===--------------------------------------------------------------------===// -Operator PhysicalHashGroupBy::make( +std::shared_ptr PhysicalHashGroupBy::make( std::vector> columns, std::vector having) { PhysicalHashGroupBy *agg = new PhysicalHashGroupBy; agg->columns = columns; agg->having = move(having); - return Operator(agg); + return std::make_shared(std::shared_ptr(agg)); } -bool PhysicalHashGroupBy::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::HashGroupBy) return false; +bool PhysicalHashGroupBy::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::HashGroupBy) return false; const PhysicalHashGroupBy &r = *static_cast(&node); if (having.size() != r.having.size() || columns.size() != r.columns.size()) @@ -936,7 +932,7 @@ bool PhysicalHashGroupBy::operator==(const BaseOperatorNode &node) { } hash_t PhysicalHashGroupBy::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : having) hash = HashUtil::SumHashes(hash, pred.expr->Hash()); for (auto expr : columns) hash = HashUtil::SumHashes(hash, expr->Hash()); return hash; @@ -945,17 +941,17 @@ hash_t PhysicalHashGroupBy::Hash() const { //===--------------------------------------------------------------------===// // PhysicalSortGroupBy //===--------------------------------------------------------------------===// -Operator PhysicalSortGroupBy::make( +std::shared_ptr PhysicalSortGroupBy::make( std::vector> columns, std::vector having) { PhysicalSortGroupBy *agg = new PhysicalSortGroupBy; agg->columns = std::move(columns); agg->having = move(having); - return Operator(agg); + return std::make_shared(std::shared_ptr(agg)); } -bool PhysicalSortGroupBy::operator==(const BaseOperatorNode &node) { - if (node.GetType() != OpType::SortGroupBy) return false; +bool PhysicalSortGroupBy::operator==(const AbstractNode &node) { + if (node.GetOpType() != OpType::SortGroupBy) return false; const PhysicalSortGroupBy &r = *static_cast(&node); if (having.size() != r.having.size() || columns.size() != r.columns.size()) @@ -967,7 +963,7 @@ bool PhysicalSortGroupBy::operator==(const BaseOperatorNode &node) { } hash_t PhysicalSortGroupBy::Hash() const { - hash_t hash = BaseOperatorNode::Hash(); + hash_t hash = AbstractNode::Hash(); for (auto &pred : having) hash = HashUtil::SumHashes(hash, pred.expr->Hash()); for (auto expr : columns) hash = HashUtil::SumHashes(hash, expr->Hash()); return hash; @@ -976,17 +972,15 @@ hash_t PhysicalSortGroupBy::Hash() const { //===--------------------------------------------------------------------===// // PhysicalAggregate //===--------------------------------------------------------------------===// -Operator PhysicalAggregate::make() { - PhysicalAggregate *agg = new PhysicalAggregate; - return Operator(agg); +std::shared_ptr PhysicalAggregate::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// // Physical Hash //===--------------------------------------------------------------------===// -Operator PhysicalDistinct::make() { - PhysicalDistinct *hash = new PhysicalDistinct; - return Operator(hash); +std::shared_ptr PhysicalDistinct::make() { + return std::make_shared(); } //===--------------------------------------------------------------------===// @@ -1099,113 +1093,118 @@ std::string OperatorNode::name_ = //===--------------------------------------------------------------------===// template <> -OpType OperatorNode::type_ = OpType::Leaf; +OpType OperatorNode::op_type_ = OpType::Leaf; template <> -OpType OperatorNode::type_ = OpType::Get; +OpType OperatorNode::op_type_ = OpType::Get; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::LogicalExternalFileGet; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::LogicalQueryDerivedGet; template <> -OpType OperatorNode::type_ = OpType::LogicalFilter; +OpType OperatorNode::op_type_ = OpType::LogicalFilter; template <> -OpType OperatorNode::type_ = OpType::LogicalProjection; +OpType OperatorNode::op_type_ = OpType::LogicalProjection; template <> -OpType OperatorNode::type_ = OpType::LogicalMarkJoin; +OpType OperatorNode::op_type_ = OpType::LogicalMarkJoin; template <> -OpType OperatorNode::type_ = OpType::LogicalSingleJoin; +OpType OperatorNode::op_type_ = OpType::LogicalSingleJoin; template <> -OpType OperatorNode::type_ = OpType::LogicalDependentJoin; +OpType OperatorNode::op_type_ = OpType::LogicalDependentJoin; template <> -OpType OperatorNode::type_ = OpType::InnerJoin; +OpType OperatorNode::op_type_ = OpType::InnerJoin; template <> -OpType OperatorNode::type_ = OpType::LeftJoin; +OpType OperatorNode::op_type_ = OpType::LeftJoin; template <> -OpType OperatorNode::type_ = OpType::RightJoin; +OpType OperatorNode::op_type_ = OpType::RightJoin; template <> -OpType OperatorNode::type_ = OpType::OuterJoin; +OpType OperatorNode::op_type_ = OpType::OuterJoin; template <> -OpType OperatorNode::type_ = OpType::SemiJoin; +OpType OperatorNode::op_type_ = OpType::SemiJoin; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::LogicalAggregateAndGroupBy; template <> -OpType OperatorNode::type_ = OpType::LogicalInsert; +OpType OperatorNode::op_type_ = OpType::LogicalInsert; template <> -OpType OperatorNode::type_ = OpType::LogicalInsertSelect; +OpType OperatorNode::op_type_ = OpType::LogicalInsertSelect; template <> -OpType OperatorNode::type_ = OpType::LogicalUpdate; +OpType OperatorNode::op_type_ = OpType::LogicalUpdate; template <> -OpType OperatorNode::type_ = OpType::LogicalDelete; +OpType OperatorNode::op_type_ = OpType::LogicalDelete; template <> -OpType OperatorNode::type_ = OpType::LogicalDistinct; +OpType OperatorNode::op_type_ = OpType::LogicalDistinct; template <> -OpType OperatorNode::type_ = OpType::LogicalLimit; +OpType OperatorNode::op_type_ = OpType::LogicalLimit; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::LogicalExportExternalFile; template <> -OpType OperatorNode::type_ = OpType::DummyScan; +OpType OperatorNode::op_type_ = OpType::DummyScan; template <> -OpType OperatorNode::type_ = OpType::SeqScan; +OpType OperatorNode::op_type_ = OpType::SeqScan; template <> -OpType OperatorNode::type_ = OpType::IndexScan; +OpType OperatorNode::op_type_ = OpType::IndexScan; template <> -OpType OperatorNode::type_ = OpType::ExternalFileScan; +OpType OperatorNode::op_type_ = OpType::ExternalFileScan; template <> -OpType OperatorNode::type_ = OpType::QueryDerivedScan; +OpType OperatorNode::op_type_ = OpType::QueryDerivedScan; template <> -OpType OperatorNode::type_ = OpType::OrderBy; +OpType OperatorNode::op_type_ = OpType::OrderBy; template <> -OpType OperatorNode::type_ = OpType::Distinct; +OpType OperatorNode::op_type_ = OpType::Distinct; template <> -OpType OperatorNode::type_ = OpType::PhysicalLimit; +OpType OperatorNode::op_type_ = OpType::PhysicalLimit; template <> -OpType OperatorNode::type_ = OpType::InnerNLJoin; +OpType OperatorNode::op_type_ = OpType::InnerNLJoin; template <> -OpType OperatorNode::type_ = OpType::LeftNLJoin; +OpType OperatorNode::op_type_ = OpType::LeftNLJoin; template <> -OpType OperatorNode::type_ = OpType::RightNLJoin; +OpType OperatorNode::op_type_ = OpType::RightNLJoin; template <> -OpType OperatorNode::type_ = OpType::OuterNLJoin; +OpType OperatorNode::op_type_ = OpType::OuterNLJoin; template <> -OpType OperatorNode::type_ = OpType::InnerHashJoin; +OpType OperatorNode::op_type_ = OpType::InnerHashJoin; template <> -OpType OperatorNode::type_ = OpType::LeftHashJoin; +OpType OperatorNode::op_type_ = OpType::LeftHashJoin; template <> -OpType OperatorNode::type_ = OpType::RightHashJoin; +OpType OperatorNode::op_type_ = OpType::RightHashJoin; template <> -OpType OperatorNode::type_ = OpType::OuterHashJoin; +OpType OperatorNode::op_type_ = OpType::OuterHashJoin; template <> -OpType OperatorNode::type_ = OpType::Insert; +OpType OperatorNode::op_type_ = OpType::Insert; template <> -OpType OperatorNode::type_ = OpType::InsertSelect; +OpType OperatorNode::op_type_ = OpType::InsertSelect; template <> -OpType OperatorNode::type_ = OpType::Delete; +OpType OperatorNode::op_type_ = OpType::Delete; template <> -OpType OperatorNode::type_ = OpType::Update; +OpType OperatorNode::op_type_ = OpType::Update; template <> -OpType OperatorNode::type_ = OpType::HashGroupBy; +OpType OperatorNode::op_type_ = OpType::HashGroupBy; template <> -OpType OperatorNode::type_ = OpType::SortGroupBy; +OpType OperatorNode::op_type_ = OpType::SortGroupBy; template <> -OpType OperatorNode::type_ = OpType::Aggregate; +OpType OperatorNode::op_type_ = OpType::Aggregate; template <> -OpType OperatorNode::type_ = +OpType OperatorNode::op_type_ = OpType::ExportExternalFile; + +//===--------------------------------------------------------------------===// +template +ExpressionType OperatorNode::exp_type_ = ExpressionType::INVALID; + //===--------------------------------------------------------------------===// template bool OperatorNode::IsLogical() const { - return type_ < OpType::LogicalPhysicalDelimiter; + return op_type_ < OpType::LogicalPhysicalDelimiter; } template bool OperatorNode::IsPhysical() const { - return type_ > OpType::LogicalPhysicalDelimiter; + return op_type_ > OpType::LogicalPhysicalDelimiter; } template <> diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index 83bcadde4de..e7f648e9375 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -327,7 +327,7 @@ const std::string Optimizer::GetOperatorInfo( auto gexpr = group->GetBestExpression(required_props); os << std::endl << StringUtil::Indent(num_indent) << "operator name: " - << gexpr->Op().GetName().c_str(); + << gexpr->Node()->GetName().c_str(); vector child_groups = gexpr->GetChildGroupIDs(); auto required_input_props = gexpr->GetInputProperties(required_props); @@ -352,7 +352,7 @@ unique_ptr Optimizer::ChooseBestPlan( auto gexpr = group->GetBestExpression(required_props); LOG_TRACE("Choosing best plan for group %d with op %s", gexpr->GetGroupID(), - gexpr->Op().GetName().c_str()); + gexpr->Op()->GetName().c_str()); vector child_groups = gexpr->GetChildGroupIDs(); auto required_input_props = gexpr->GetInputProperties(required_props); @@ -383,8 +383,7 @@ unique_ptr Optimizer::ChooseBestPlan( } // Derive root plan - shared_ptr op = - make_shared(gexpr->Op()); + std::shared_ptr op(make_shared(gexpr->Node())); PlanGenerator generator; auto plan = generator.ConvertOpExpression(op, required_props, required_cols, diff --git a/src/optimizer/optimizer_task.cpp b/src/optimizer/optimizer_task.cpp index e1cfac5643d..24edd6d6876 100644 --- a/src/optimizer/optimizer_task.cpp +++ b/src/optimizer/optimizer_task.cpp @@ -18,6 +18,7 @@ #include "optimizer/child_property_deriver.h" #include "optimizer/stats/stats_calculator.h" #include "optimizer/stats/child_stats_deriver.h" +#include "optimizer/absexpr_expression.h" namespace peloton { namespace optimizer { @@ -31,8 +32,9 @@ void OptimizerTask::ConstructValidRules( std::vector &valid_rules) { for (auto &rule : rules) { // Check if we can apply the rule - bool root_pattern_mismatch = - group_expr->Op().GetType() != rule->GetMatchPattern()->Type(); + bool root_pattern_mismatch = group_expr->Node()->GetOpType() != rule->GetMatchPattern()->GetOpType() || + group_expr->Node()->GetExpType() != rule->GetMatchPattern()->GetExpType(); + bool already_explored = group_expr->HasRuleExplored(rule.get()); bool child_pattern_mismatch = group_expr->GetChildrenGroupsSize() != @@ -89,15 +91,15 @@ void OptimizeExpression::execute() { std::vector valid_rules; // Construct valid transformation rules from rule set - ConstructValidRules(group_expr_, context_.get(), - GetRuleSet().GetTransformationRules(), valid_rules); + this->ConstructValidRules(group_expr_, context_.get(), + GetRuleSet().GetTransformationRules(), valid_rules); // Construct valid implementation rules from rule set - ConstructValidRules(group_expr_, context_.get(), - GetRuleSet().GetImplementationRules(), valid_rules); + this->ConstructValidRules(group_expr_, context_.get(), + GetRuleSet().GetImplementationRules(), valid_rules); std::sort(valid_rules.begin(), valid_rules.end()); LOG_DEBUG("OptimizeExpression::execute() op %d, valid rules : %lu", - static_cast(group_expr_->Op().GetType()), valid_rules.size()); + static_cast(group_expr_->Node()->GetOpType()), valid_rules.size()); // Apply rule for (auto &r : valid_rules) { PushTask(new ApplyRule(group_expr_, r.rule, context_)); @@ -172,22 +174,21 @@ void ApplyRule::execute() { LOG_TRACE("ApplyRule::execute() for rule: %d", rule_->GetRuleIdx()); if (group_expr_->HasRuleExplored(rule_)) return; - GroupExprBindingIterator iterator(GetMemo(), group_expr_, - rule_->GetMatchPattern()); + GroupExprBindingIterator iterator(GetMemo(), group_expr_, rule_->GetMatchPattern()); while (iterator.HasNext()) { auto before = iterator.Next(); if (!rule_->Check(before, context_.get())) { continue; } - std::vector> after; + std::vector> after; rule_->Transform(before, after, context_.get()); for (auto &new_expr : after) { std::shared_ptr new_gexpr; if (context_->metadata->RecordTransformedExpression( new_expr, new_gexpr, group_expr_->GetGroupID())) { // A new group expression is generated - if (new_gexpr->Op().IsLogical()) { + if (new_gexpr->Node()->IsLogical()) { // Derive stats for the *logical expression* PushTask(new DeriveStats(new_gexpr.get(), ExprSet{}, context_)); if (explore_only) { @@ -402,106 +403,141 @@ void OptimizeInputs::execute() { } } -void TopDownRewrite::execute() { - std::vector valid_rules; +// =================================================================== +// +// RewriteTask related implementations +// +// =================================================================== +std::set RewriteTask::GetUniqueChildGroupIDs() { + // Get current group and logical expressions + auto cur_group = this->GetMemo().GetGroupByID(group_id_); + auto cur_group_exprs = cur_group->GetLogicalExpressions(); + PELOTON_ASSERT(cur_group_exprs.size() >= 1); + + // Generate unique group ID numbers so we don't repeat work + std::set child_groups; + for (auto cur_group_expr : cur_group_exprs) { + for (size_t child = 0; child < cur_group_expr->GetChildrenGroupsSize(); child++) { + child_groups.insert(cur_group_expr->GetChildGroupId(child)); + } + } - auto cur_group = GetMemo().GetGroupByID(group_id_); - auto cur_group_expr = cur_group->GetLogicalExpression(); + return child_groups; +} - // Construct valid transformation rules from rule set - ConstructValidRules(cur_group_expr, context_.get(), - GetRuleSet().GetRewriteRulesByName(rule_set_name_), - valid_rules); +bool RewriteTask::OptimizeCurrentGroup(bool replace_on_match) { + std::vector valid_rules; - // Sort so that we apply rewrite rules with higher promise first - std::sort(valid_rules.begin(), valid_rules.end(), - std::greater()); + // Get current group and logical expressions + auto cur_group = this->GetMemo().GetGroupByID(group_id_); + auto cur_group_exprs = cur_group->GetLogicalExpressions(); + PELOTON_ASSERT(cur_group_exprs.size() >= 1); + + // Try to optimize all the logical group expressions. + // If one gets optimized, then the group is collapsed. + for (auto cur_group_expr_ptr : cur_group_exprs) { + auto cur_group_expr = cur_group_expr_ptr.get(); + + // Construct valid transformation rules from rule set + this->ConstructValidRules(cur_group_expr, this->context_.get(), + this->GetRuleSet().GetRewriteRulesByName(rule_set_name_), + valid_rules); + + // Sort so that we apply rewrite rules with higher promise first + std::sort(valid_rules.begin(), valid_rules.end(), + std::greater()); + + // Try applying each rule + for (auto &r : valid_rules) { + GroupExprBindingIterator iterator(this->GetMemo(), cur_group_expr, r.rule->GetMatchPattern()); + // Keep trying to apply until we exhaust all the bindings. + // This could possibly be sub-optimal since the first binding that results + // in a transformation by a rule will be applied and become the group's + // "new" rewritten expression. + while (iterator.HasNext()) { + // Binding succeeded to a given expression structure + auto before = iterator.Next(); + + // Attempt to apply the transformation + std::vector> after; + r.rule->Transform(before, after, this->context_.get()); + + // Rewrite rule should provide at most 1 expression + PELOTON_ASSERT(after.size() <= 1); + if (!after.empty()) { + // The transformation produced another expression + auto &new_expr = after[0]; + if (replace_on_match) { + // Replace entire group. We do not need to generate logically equivalent + // because rewriting expressions will not generate new AND or OR clauses. + this->context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); + + // Return true to indicate optimize succeeded and the caller should try again + return true; + } else { + // Insert as a new logical equivalent expression + std::shared_ptr new_gexpr; + GroupID group = cur_group_expr->GetGroupID(); - for (auto &r : valid_rules) { - GroupExprBindingIterator iterator(GetMemo(), cur_group_expr, - r.rule->GetMatchPattern()); - if (iterator.HasNext()) { - auto before = iterator.Next(); - PELOTON_ASSERT(!iterator.HasNext()); - std::vector> after; - r.rule->Transform(before, after, context_.get()); - - // Rewrite rule should provide at most 1 expression - PELOTON_ASSERT(after.size() <= 1); - // If a rule is applied, we replace the old expression and optimize this - // group again, this will ensure that we apply rule for this level until - // saturated - if (!after.empty()) { - auto &new_expr = after[0]; - context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); - PushTask(new TopDownRewrite(group_id_, context_, rule_set_name_)); - return; + // Try again only if we succeeded in recording a new expression + return this->context_->metadata->RecordTransformedExpression(new_expr, new_gexpr, group); + } + } } + + cur_group_expr->SetRuleExplored(r.rule); } - cur_group_expr->SetRuleExplored(r.rule); } - for (size_t child_group_idx = 0; - child_group_idx < cur_group_expr->GetChildrenGroupsSize(); - child_group_idx++) { - // Need to rewrite all sub trees first - PushTask( - new TopDownRewrite(cur_group_expr->GetChildGroupId(child_group_idx), - context_, rule_set_name_)); - } + return false; } -void BottomUpRewrite::execute() { - std::vector valid_rules; +void TopDownRewrite::execute() { + bool did_optimize = this->OptimizeCurrentGroup(replace_on_transform_); + + // Optimize succeeded and by the design, there will ever be only 1 + // that is logically equivalent, so we do not need to perform + // any extra passes. Equivalence generating rules will not be repeatedly + // applied to expression trees. + // + // This is definitely sub-optimal and is a missed opportunity for rewrite. + // However, this requires AbstractExpression to support strict equality + // in its post-binding state. + if (did_optimize && replace_on_transform_) { + auto top = new TopDownRewrite(this->group_id_, this->context_, this->rule_set_name_); + top->SetReplaceOnTransform(replace_on_transform_); + this->PushTask(top); + return; + } - auto cur_group = GetMemo().GetGroupByID(group_id_); - auto cur_group_expr = cur_group->GetLogicalExpression(); + // This group has been optimized, so move on to the children + std::set child_groups = this->GetUniqueChildGroupIDs(); + for (auto g_id : child_groups) { + auto top = new TopDownRewrite(g_id, this->context_, this->rule_set_name_); + top->SetReplaceOnTransform(replace_on_transform_); + this->PushTask(top); + } +} +void BottomUpRewrite::execute() { if (!has_optimized_child_) { - PushTask(new BottomUpRewrite(group_id_, context_, rule_set_name_, true)); - for (size_t child_group_idx = 0; - child_group_idx < cur_group_expr->GetChildrenGroupsSize(); - child_group_idx++) { - // Need to rewrite all sub trees first - PushTask( - new BottomUpRewrite(cur_group_expr->GetChildGroupId(child_group_idx), - context_, rule_set_name_, false)); + this->PushTask(new BottomUpRewrite(this->group_id_, this->context_, this->rule_set_name_, true)); + + // Get all unique GroupIDs to minimize repeated work + // Need to rewrite all sub trees first + std::set child_groups = this->GetUniqueChildGroupIDs(); + for (auto g_id : child_groups) { + this->PushTask(new BottomUpRewrite(g_id, this->context_, this->rule_set_name_, false)); } + return; } - // Construct valid transformation rules from rule set - ConstructValidRules(cur_group_expr, context_.get(), - GetRuleSet().GetRewriteRulesByName(rule_set_name_), - valid_rules); - - // Sort so that we apply rewrite rules with higher promise first - std::sort(valid_rules.begin(), valid_rules.end(), - std::greater()); - for (auto &r : valid_rules) { - GroupExprBindingIterator iterator(GetMemo(), cur_group_expr, - r.rule->GetMatchPattern()); - if (iterator.HasNext()) { - auto before = iterator.Next(); - PELOTON_ASSERT(!iterator.HasNext()); - std::vector> after; - r.rule->Transform(before, after, context_.get()); - - // Rewrite rule should provide at most 1 expression - PELOTON_ASSERT(after.size() <= 1); - // If a rule is applied, we replace the old expression and optimize this - // group again, this will ensure that we apply rule for this level until - // saturated, also childs are already been rewritten - if (!after.empty()) { - auto &new_expr = after[0]; - context_->metadata->ReplaceRewritedExpression(new_expr, group_id_); - PushTask( - new BottomUpRewrite(group_id_, context_, rule_set_name_, false)); - return; - } - } - cur_group_expr->SetRuleExplored(r.rule); + // Keep rewriting until we finish + if (this->OptimizeCurrentGroup(true)) { + this->PushTask(new BottomUpRewrite(this->group_id_, this->context_, this->rule_set_name_, false)); } } + } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/pattern.cpp b/src/optimizer/pattern.cpp index d7665d678bb..81fd8b7d321 100644 --- a/src/optimizer/pattern.cpp +++ b/src/optimizer/pattern.cpp @@ -15,7 +15,8 @@ namespace peloton { namespace optimizer { -Pattern::Pattern(OpType op) : _type(op) {} +Pattern::Pattern(OpType op) : _op_type(op) {} +Pattern::Pattern(ExpressionType exp) : _exp_type(exp) {} void Pattern::AddChild(std::shared_ptr child) { children.push_back(child); @@ -25,7 +26,8 @@ const std::vector> &Pattern::Children() const { return children; } -OpType Pattern::Type() const { return _type; } +OpType Pattern::GetOpType() const { return _op_type; } +ExpressionType Pattern::GetExpType() const { return _exp_type; } } // namespace optimizer } // namespace peloton diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index c04c8025903..6c83121d03b 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -66,7 +66,7 @@ unique_ptr PlanGenerator::ConvertOpExpression( output_cols_ = move(output_cols); children_plans_ = move(children_plans); children_expr_map_ = move(children_expr_map); - op->Op().Accept(this); + op->Node()->Accept(this); BuildProjectionPlan(); output_plan_->SetCardinality(estimated_cardinality); return move(output_plan_); diff --git a/src/optimizer/property_enforcer.cpp b/src/optimizer/property_enforcer.cpp index 834cf9a76d7..6ea66eaa09a 100644 --- a/src/optimizer/property_enforcer.cpp +++ b/src/optimizer/property_enforcer.cpp @@ -21,25 +21,26 @@ namespace optimizer { std::shared_ptr PropertyEnforcer::EnforceProperty( GroupExpression* gexpr, Property* property) { + input_gexpr_ = gexpr; property->Accept(this); return output_gexpr_; } -void PropertyEnforcer::Visit(const PropertyColumns *) { - -} +void PropertyEnforcer::Visit(const PropertyColumns *) {} void PropertyEnforcer::Visit(const PropertySort *) { std::vector child_groups(1, input_gexpr_->GetGroupID()); output_gexpr_ = - std::make_shared(PhysicalOrderBy::make(), child_groups); + std::make_shared( + std::make_shared(PhysicalOrderBy::make()), child_groups); } void PropertyEnforcer::Visit(const PropertyDistinct *) { std::vector child_groups(1, input_gexpr_->GetGroupID()); output_gexpr_ = - std::make_shared(PhysicalDistinct::make(), child_groups); + std::make_shared( + std::make_shared(PhysicalOrderBy::make()), child_groups); } void PropertyEnforcer::Visit(const PropertyLimit *) {} diff --git a/src/optimizer/rewriter.cpp b/src/optimizer/rewriter.cpp new file mode 100644 index 00000000000..5378287dbe7 --- /dev/null +++ b/src/optimizer/rewriter.cpp @@ -0,0 +1,147 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rewriter.cpp +// +// Identification: src/optimizer/rewriter.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include + +#include "optimizer/optimizer.h" +#include "optimizer/rewriter.h" +#include "common/exception.h" + +#include "optimizer/cost_model/trivial_cost_model.h" +#include "optimizer/operator_visitor.h" +#include "optimizer/optimize_context.h" +#include "optimizer/optimizer_task_pool.h" +#include "optimizer/rule.h" +#include "optimizer/rule_impls.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/absexpr_expression.h" +#include "expression/abstract_expression.h" +#include "expression/constant_value_expression.h" + +using std::vector; +using std::unordered_map; +using std::shared_ptr; +using std::unique_ptr; +using std::move; +using std::pair; +using std::make_shared; + +namespace peloton { +namespace optimizer { + +Rewriter::Rewriter() : metadata_(nullptr) { + metadata_ = OptimizerMetadata(nullptr); +} + +void Rewriter::Reset() { + metadata_ = OptimizerMetadata(nullptr); +} + +void Rewriter::RewriteLoop(int root_group_id) { + std::shared_ptr root_context = std::make_shared(&metadata_, nullptr); + auto task_stack = std::unique_ptr(new OptimizerTaskStack()); + metadata_.SetTaskPool(task_stack.get()); + + // Rewrite using all rules (which will be applied based on priority) + task_stack->Push(new BottomUpRewrite(root_group_id, root_context, RewriteRuleSetName::GENERIC_RULES, false)); + + // Generate equivalences first + auto equiv_task = new TopDownRewrite(root_group_id, root_context, RewriteRuleSetName::EQUIVALENT_TRANSFORM); + equiv_task->SetReplaceOnTransform(false); // generate equivalent + task_stack->Push(equiv_task); + + // Iterate through the task stack + while (!task_stack->Empty()) { + auto task = task_stack->Pop(); + task->execute(); + } +} + +expression::AbstractExpression* Rewriter::RebuildExpression(int root) { + auto cur_group = metadata_.memo.GetGroupByID(root); + auto exprs = cur_group->GetLogicalExpressions(); + + // If we optimized a group successfully, then it would have been + // collapsed to only a single group. If we did not optimize a group, + // then they are all equivalent, so pick any. + PELOTON_ASSERT(exprs.size() >= 1); + auto expr = exprs[0]; + + std::vector child_groups = expr->GetChildGroupIDs(); + std::vector child_exprs; + for (auto group : child_groups) { + // Build children first + expression::AbstractExpression *child = RebuildExpression(group); + PELOTON_ASSERT(child != nullptr); + + child_exprs.push_back(child); + } + + std::shared_ptr c = std::dynamic_pointer_cast(expr->Node()); + PELOTON_ASSERT(c != nullptr); + + return c->CopyWithChildren(child_exprs); +} + +std::shared_ptr Rewriter::ConvertToAbsExpr(const expression::AbstractExpression* expr) { + // TODO(): remove the Copy invocation when in terrier since terrier uses shared_ptr + // + // This Copy() is not very efficient at all. but within Peloton, this is the only way + // to present a std::shared_ptr to the AbsExprNode/Expression classes. In terrier, + // this Copy() is *definitely* not needed because the AbstractExpression there already + // utilizes std::shared_ptr properly. + std::shared_ptr copy = std::shared_ptr(expr->Copy()); + + // Create current AbsExprExpression + auto container = std::make_shared(copy); + auto expression = std::make_shared(container); + + // Convert all the children + size_t child_count = expr->GetChildrenSize(); + for (size_t i = 0; i < child_count; i++) { + expression->PushChild(ConvertToAbsExpr(expr->GetChild(i))); + } + + copy->ClearChildren(); + return expression; +} + +std::shared_ptr Rewriter::RecordTreeGroups(const expression::AbstractExpression *expr) { + std::shared_ptr exp = ConvertToAbsExpr(expr); + std::shared_ptr gexpr; + metadata_.RecordTransformedExpression(exp, gexpr); + return gexpr; +} + +expression::AbstractExpression* Rewriter::RewriteExpression(const expression::AbstractExpression *expr) { + if (expr == nullptr) + return nullptr; + + // This is needed in order to provide template classes the correct interface + // and also handle immutable AbstractExpression. + std::shared_ptr gexpr = RecordTreeGroups(expr); + LOG_DEBUG("Converted tree to internal data structures"); + + GroupID root_id = gexpr->GetGroupID(); + RewriteLoop(root_id); + LOG_DEBUG("Performed rewrite loop pass"); + + expression::AbstractExpression *expr_tree = RebuildExpression(root_id); + LOG_DEBUG("Rebuilt expression tree from memo table"); + + Reset(); + LOG_DEBUG("Reset the rewriter"); + return expr_tree; +} + +} // namespace optimizer +} // namespace peloton diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index 8c72ed17fa8..bb2518d8732 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -12,22 +12,79 @@ #include "optimizer/rule_impls.h" #include "optimizer/group_expression.h" +#include "optimizer/absexpr_expression.h" +#include "optimizer/rule_rewrite.h" namespace peloton { namespace optimizer { int Rule::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); + auto root_type_exp = match_pattern->GetExpType(); + // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Undefined && + root_type != OpType::Leaf && + root_type != group_expr->Node()->GetOpType()) { + return 0; + } + + if (root_type_exp != ExpressionType::INVALID && + root_type_exp != ExpressionType::GROUP_MARKER && + root_type_exp != group_expr->Node()->GetExpType()) { return 0; } + if (IsPhysical()) return PHYS_PROMISE; return LOG_PROMISE; } RuleSet::RuleSet() { + // Comparator Elimination related rules + std::vector> comp_elim_pairs = { + std::make_pair(RuleType::CONSTANT_COMPARE_EQUAL, ExpressionType::COMPARE_EQUAL), + std::make_pair(RuleType::CONSTANT_COMPARE_NOTEQUAL, ExpressionType::COMPARE_NOTEQUAL), + std::make_pair(RuleType::CONSTANT_COMPARE_LESSTHAN, ExpressionType::COMPARE_LESSTHAN), + std::make_pair(RuleType::CONSTANT_COMPARE_GREATERTHAN, ExpressionType::COMPARE_GREATERTHAN), + std::make_pair(RuleType::CONSTANT_COMPARE_LESSTHANOREQUALTO, ExpressionType::COMPARE_LESSTHANOREQUALTO), + std::make_pair(RuleType::CONSTANT_COMPARE_GREATERTHANOREQUALTO, ExpressionType::COMPARE_GREATERTHANOREQUALTO) + }; + + for (auto &pair : comp_elim_pairs) { + AddRewriteRule( + RewriteRuleSetName::GENERIC_RULES, + new ComparatorElimination(pair.first, pair.second) + ); + } + + // Equivalent Transform related rules (flip AND, OR, EQUAL) + std::vector> equiv_pairs = { + std::make_pair(RuleType::EQUIV_AND, ExpressionType::CONJUNCTION_AND), + std::make_pair(RuleType::EQUIV_OR, ExpressionType::CONJUNCTION_OR), + std::make_pair(RuleType::EQUIV_COMPARE_EQUAL, ExpressionType::COMPARE_EQUAL) + }; + for (auto &pair : equiv_pairs) { + AddRewriteRule( + RewriteRuleSetName::EQUIVALENT_TRANSFORM, + new EquivalentTransform(pair.first, pair.second) + ); + } + + // Additional rules + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TVEqualityWithTwoCVTransform()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TransitiveClosureConstantTransform()); + + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new AndShortCircuit()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new OrShortCircuit()); + + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new NullLookupOnNotNullColumn()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new NotNullLookupOnNotNullColumn()); + + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TVEqualityWithTwoCVTransform()); + AddRewriteRule(RewriteRuleSetName::GENERIC_RULES, new TransitiveClosureConstantTransform()); + + // Define transformation/implementation rules AddTransformationRule(new InnerJoinCommutativity()); AddTransformationRule(new InnerJoinAssociativity()); AddImplementationRule(new LogicalDeleteToPhysical()); @@ -47,6 +104,7 @@ RuleSet::RuleSet() { AddImplementationRule(new ImplementLimit()); AddImplementationRule(new LogicalExportToPhysicalExport()); + // Query optimizer related rewrite rules AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, new PushFilterThroughJoin()); AddRewriteRule(RewriteRuleSetName::PREDICATE_PUSH_DOWN, diff --git a/src/optimizer/rule_impls.cpp b/src/optimizer/rule_impls.cpp index 8574e00f337..530bfedee99 100644 --- a/src/optimizer/rule_impls.cpp +++ b/src/optimizer/rule_impls.cpp @@ -41,7 +41,7 @@ InnerJoinCommutativity::InnerJoinCommutativity() { match_pattern->AddChild(right_child); } -bool InnerJoinCommutativity::Check(std::shared_ptr expr, +bool InnerJoinCommutativity::Check(std::shared_ptr expr, OptimizeContext *context) const { (void)context; (void)expr; @@ -49,19 +49,21 @@ bool InnerJoinCommutativity::Check(std::shared_ptr expr, } void InnerJoinCommutativity::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - auto join_op = input->Op().As(); + auto join_op = input->Node()->As(); auto join_predicates = std::vector(join_op->join_predicates); + + auto result_plan = std::make_shared( LogicalInnerJoin::make(join_predicates)); - std::vector> children = input->Children(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 2); LOG_TRACE( "Reorder left child with op %s and right child with op %s for inner join", - children[0]->Op().GetName().c_str(), children[1]->Op().GetName().c_str()); + children[0]->Node()->GetName().c_str(), children[1]->Node()->GetName().c_str()); result_plan->PushChild(children[1]); result_plan->PushChild(children[0]); @@ -86,7 +88,7 @@ InnerJoinAssociativity::InnerJoinAssociativity() { } // TODO: As far as I know, theres nothing else that needs to be checked -bool InnerJoinAssociativity::Check(std::shared_ptr expr, +bool InnerJoinAssociativity::Check(std::shared_ptr expr, OptimizeContext *context) const { (void)context; (void)expr; @@ -94,29 +96,29 @@ bool InnerJoinAssociativity::Check(std::shared_ptr expr, } void InnerJoinAssociativity::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const { // NOTE: Transforms (left JOIN middle) JOIN right -> left JOIN (middle JOIN // right) Variables are named accordingly to above transformation - auto parent_join = input->Op().As(); - std::vector> children = input->Children(); + auto parent_join = input->Node()->As(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 2); - PELOTON_ASSERT(children[0]->Op().GetType() == OpType::InnerJoin); + PELOTON_ASSERT(children[0]->Node()->GetOpType() == OpType::InnerJoin); PELOTON_ASSERT(children[0]->Children().size() == 2); - auto child_join = children[0]->Op().As(); + auto child_join = children[0]->Node()->As(); auto left = children[0]->Children()[0]; auto middle = children[0]->Children()[1]; auto right = children[1]; LOG_DEBUG("Reordered join structured: (%s JOIN %s) JOIN %s", - left->Op().GetName().c_str(), middle->Op().GetName().c_str(), - right->Op().GetName().c_str()); + left->Node()->GetName().c_str(), middle->Node()->GetName().c_str(), + right->Node()->GetName().c_str()); // Get Alias sets auto &memo = context->metadata->memo; - auto middle_group_id = middle->Op().As()->origin_group; - auto right_group_id = right->Op().As()->origin_group; + auto middle_group_id = middle->Node()->As()->origin_group; + auto right_group_id = right->Node()->As()->origin_group; const auto &middle_group_aliases_set = memo.GetGroupByID(middle_group_id)->GetTableAliases(); @@ -154,14 +156,14 @@ void InnerJoinAssociativity::Transform( } // Construct new child join operator - std::shared_ptr new_child_join = + std::shared_ptr new_child_join = std::make_shared( LogicalInnerJoin::make(new_child_join_predicates)); new_child_join->PushChild(middle); new_child_join->PushChild(right); // Construct new parent join operator - std::shared_ptr new_parent_join = + std::shared_ptr new_parent_join = std::make_shared( LogicalInnerJoin::make(new_parent_join_predicates)); new_parent_join->PushChild(left); @@ -182,16 +184,16 @@ GetToDummyScan::GetToDummyScan() { match_pattern = std::make_shared(OpType::Get); } -bool GetToDummyScan::Check(std::shared_ptr plan, +bool GetToDummyScan::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; - const LogicalGet *get = plan->Op().As(); + const LogicalGet *get = plan->Node()->As(); return get->table == nullptr; } void GetToDummyScan::Transform( - UNUSED_ATTRIBUTE std::shared_ptr input, - std::vector> &transformed, + UNUSED_ATTRIBUTE std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { auto result_plan = std::make_shared(DummyScan::make()); @@ -206,24 +208,24 @@ GetToSeqScan::GetToSeqScan() { match_pattern = std::make_shared(OpType::Get); } -bool GetToSeqScan::Check(std::shared_ptr plan, +bool GetToSeqScan::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; - const LogicalGet *get = plan->Op().As(); + const LogicalGet *get = plan->Node()->As(); return get->table != nullptr; } void GetToSeqScan::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalGet *get = input->Op().As(); + const LogicalGet *get = input->Node()->As(); auto result_plan = std::make_shared( PhysicalSeqScan::make(get->get_id, get->table, get->table_alias, get->predicates, get->is_for_update)); - UNUSED_ATTRIBUTE std::vector> children = + UNUSED_ATTRIBUTE std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 0); @@ -238,12 +240,12 @@ GetToIndexScan::GetToIndexScan() { match_pattern = std::make_shared(OpType::Get); } -bool GetToIndexScan::Check(std::shared_ptr plan, +bool GetToIndexScan::Check(std::shared_ptr plan, OptimizeContext *context) const { // If there is a index for the table, return true, // else return false (void)context; - const LogicalGet *get = plan->Op().As(); + const LogicalGet *get = plan->Node()->As(); bool index_exist = false; if (get != nullptr && get->table != nullptr && !get->table->GetIndexCatalogEntries().empty()) { @@ -253,14 +255,14 @@ bool GetToIndexScan::Check(std::shared_ptr plan, } void GetToIndexScan::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - UNUSED_ATTRIBUTE std::vector> children = + UNUSED_ATTRIBUTE std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 0); - const LogicalGet *get = input->Op().As(); + const LogicalGet *get = input->Node()->As(); // Get sort columns if they are all base columns and all in asc order auto sort = context->required_prop->GetPropertyOfType(PropertyType::SORT); @@ -415,17 +417,17 @@ LogicalQueryDerivedGetToPhysical::LogicalQueryDerivedGetToPhysical() { } bool LogicalQueryDerivedGetToPhysical::Check( - std::shared_ptr expr, OptimizeContext *context) const { + std::shared_ptr expr, OptimizeContext *context) const { (void)context; (void)expr; return true; } void LogicalQueryDerivedGetToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalQueryDerivedGet *get = input->Op().As(); + const LogicalQueryDerivedGet *get = input->Node()->As(); auto result_plan = std::make_shared(QueryDerivedScan::make( @@ -443,16 +445,16 @@ LogicalExternalFileGetToPhysical::LogicalExternalFileGetToPhysical() { } bool LogicalExternalFileGetToPhysical::Check( - UNUSED_ATTRIBUTE std::shared_ptr plan, + UNUSED_ATTRIBUTE std::shared_ptr plan, UNUSED_ATTRIBUTE OptimizeContext *context) const { return true; } void LogicalExternalFileGetToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const auto *get = input->Op().As(); + const auto *get = input->Node()->As(); auto result_plan = std::make_shared( ExternalFileScan::make(get->get_id, get->format, get->file_name, @@ -472,7 +474,7 @@ LogicalDeleteToPhysical::LogicalDeleteToPhysical() { match_pattern->AddChild(child); } -bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, +bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; @@ -480,10 +482,10 @@ bool LogicalDeleteToPhysical::Check(std::shared_ptr plan, } void LogicalDeleteToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalDelete *delete_op = input->Op().As(); + const LogicalDelete *delete_op = input->Node()->As(); auto result = std::make_shared( PhysicalDelete::make(delete_op->target_table)); PELOTON_ASSERT(input->Children().size() == 1); @@ -500,7 +502,7 @@ LogicalUpdateToPhysical::LogicalUpdateToPhysical() { match_pattern->AddChild(child); } -bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, +bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; @@ -508,10 +510,10 @@ bool LogicalUpdateToPhysical::Check(std::shared_ptr plan, } void LogicalUpdateToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalUpdate *update_op = input->Op().As(); + const LogicalUpdate *update_op = input->Node()->As(); auto result = std::make_shared( PhysicalUpdate::make(update_op->target_table, update_op->updates)); PELOTON_ASSERT(input->Children().size() != 0); @@ -528,7 +530,7 @@ LogicalInsertToPhysical::LogicalInsertToPhysical() { // match_pattern->AddChild(child); } -bool LogicalInsertToPhysical::Check(std::shared_ptr plan, +bool LogicalInsertToPhysical::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; @@ -536,10 +538,11 @@ bool LogicalInsertToPhysical::Check(std::shared_ptr plan, } void LogicalInsertToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalInsert *insert_op = input->Op().As(); + const LogicalInsert *insert_op = input->Node()->As(); + std::cout << insert_op << " val" << std::endl; auto result = std::make_shared(PhysicalInsert::make( insert_op->target_table, insert_op->columns, insert_op->values)); PELOTON_ASSERT(input->Children().size() == 0); @@ -557,17 +560,17 @@ LogicalInsertSelectToPhysical::LogicalInsertSelectToPhysical() { } bool LogicalInsertSelectToPhysical::Check( - std::shared_ptr plan, OptimizeContext *context) const { + std::shared_ptr plan, OptimizeContext *context) const { (void)plan; (void)context; return true; } void LogicalInsertSelectToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const LogicalInsertSelect *insert_op = input->Op().As(); + const LogicalInsertSelect *insert_op = input->Node()->As(); auto result = std::make_shared( PhysicalInsertSelect::make(insert_op->target_table)); PELOTON_ASSERT(input->Children().size() == 1); @@ -585,20 +588,20 @@ LogicalGroupByToHashGroupBy::LogicalGroupByToHashGroupBy() { } bool LogicalGroupByToHashGroupBy::Check( - UNUSED_ATTRIBUTE std::shared_ptr plan, + UNUSED_ATTRIBUTE std::shared_ptr plan, OptimizeContext *context) const { (void)context; const LogicalAggregateAndGroupBy *agg_op = - plan->Op().As(); + plan->Node()->As(); return !agg_op->columns.empty(); } void LogicalGroupByToHashGroupBy::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { const LogicalAggregateAndGroupBy *agg_op = - input->Op().As(); + input->Node()->As(); auto result = std::make_shared( PhysicalHashGroupBy::make(agg_op->columns, agg_op->having)); PELOTON_ASSERT(input->Children().size() == 1); @@ -616,17 +619,17 @@ LogicalAggregateToPhysical::LogicalAggregateToPhysical() { } bool LogicalAggregateToPhysical::Check( - UNUSED_ATTRIBUTE std::shared_ptr plan, + UNUSED_ATTRIBUTE std::shared_ptr plan, OptimizeContext *context) const { (void)context; const LogicalAggregateAndGroupBy *agg_op = - plan->Op().As(); + plan->Node()->As(); return agg_op->columns.empty(); } void LogicalAggregateToPhysical::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { auto result = std::make_shared(PhysicalAggregate::make()); PELOTON_ASSERT(input->Children().size() == 1); @@ -653,7 +656,7 @@ InnerJoinToInnerNLJoin::InnerJoinToInnerNLJoin() { return; } -bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, +bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -661,16 +664,16 @@ bool InnerJoinToInnerNLJoin::Check(std::shared_ptr plan, } void InnerJoinToInnerNLJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join - const LogicalInnerJoin *inner_join = input->Op().As(); + const LogicalInnerJoin *inner_join = input->Node()->As(); auto children = input->Children(); PELOTON_ASSERT(children.size() == 2); - auto left_group_id = children[0]->Op().As()->origin_group; - auto right_group_id = children[1]->Op().As()->origin_group; + auto left_group_id = children[0]->Node()->As()->origin_group; + auto right_group_id = children[1]->Node()->As()->origin_group; auto &left_group_alias = context->metadata->memo.GetGroupByID(left_group_id)->GetTableAliases(); auto &right_group_alias = @@ -714,7 +717,7 @@ InnerJoinToInnerHashJoin::InnerJoinToInnerHashJoin() { return; } -bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, +bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -722,16 +725,16 @@ bool InnerJoinToInnerHashJoin::Check(std::shared_ptr plan, } void InnerJoinToInnerHashJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { // first build an expression representing hash join - const LogicalInnerJoin *inner_join = input->Op().As(); + const LogicalInnerJoin *inner_join = input->Node()->As(); auto children = input->Children(); PELOTON_ASSERT(children.size() == 2); - auto left_group_id = children[0]->Op().As()->origin_group; - auto right_group_id = children[1]->Op().As()->origin_group; + auto left_group_id = children[0]->Node()->As()->origin_group; + auto right_group_id = children[1]->Node()->As()->origin_group; auto &left_group_alias = context->metadata->memo.GetGroupByID(left_group_id)->GetTableAliases(); auto &right_group_alias = @@ -765,7 +768,7 @@ ImplementDistinct::ImplementDistinct() { match_pattern->AddChild(std::make_shared(OpType::Leaf)); } -bool ImplementDistinct::Check(std::shared_ptr plan, +bool ImplementDistinct::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -773,13 +776,13 @@ bool ImplementDistinct::Check(std::shared_ptr plan, } void ImplementDistinct::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const { (void)context; auto result_plan = std::make_shared(PhysicalDistinct::make()); - std::vector> children = input->Children(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 1); result_plan->PushChild(children[0]); @@ -796,7 +799,7 @@ ImplementLimit::ImplementLimit() { match_pattern->AddChild(std::make_shared(OpType::Leaf)); } -bool ImplementLimit::Check(std::shared_ptr plan, +bool ImplementLimit::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -804,16 +807,16 @@ bool ImplementLimit::Check(std::shared_ptr plan, } void ImplementLimit::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, OptimizeContext *context) const { (void)context; - const LogicalLimit *limit_op = input->Op().As(); + const LogicalLimit *limit_op = input->Node()->As(); auto result_plan = std::make_shared( PhysicalLimit::make(limit_op->offset, limit_op->limit, limit_op->sort_exprs, limit_op->sort_ascending)); - std::vector> children = input->Children(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 1); result_plan->PushChild(children[0]); @@ -830,23 +833,23 @@ LogicalExportToPhysicalExport::LogicalExportToPhysicalExport() { } bool LogicalExportToPhysicalExport::Check( - UNUSED_ATTRIBUTE std::shared_ptr plan, + UNUSED_ATTRIBUTE std::shared_ptr plan, UNUSED_ATTRIBUTE OptimizeContext *context) const { return true; } void LogicalExportToPhysicalExport::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - const auto *export_op = input->Op().As(); + const auto *export_op = input->Node()->As(); auto result_plan = std::make_shared(PhysicalExportExternalFile::make( export_op->format, export_op->file_name, export_op->delimiter, export_op->quote, export_op->escape)); - std::vector> children = input->Children(); + std::vector> children = input->Children(); PELOTON_ASSERT(children.size() == 1); result_plan->PushChild(children[0]); @@ -874,27 +877,27 @@ PushFilterThroughJoin::PushFilterThroughJoin() { match_pattern->AddChild(child); } -bool PushFilterThroughJoin::Check(std::shared_ptr, +bool PushFilterThroughJoin::Check(std::shared_ptr, OptimizeContext *) const { return true; } void PushFilterThroughJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PushFilterThroughJoin::Transform"); auto &memo = context->metadata->memo; auto join_op_expr = input->Children().at(0); auto &join_children = join_op_expr->Children(); - auto left_group_id = join_children[0]->Op().As()->origin_group; - auto right_group_id = join_children[1]->Op().As()->origin_group; + auto left_group_id = join_children[0]->Node()->As()->origin_group; + auto right_group_id = join_children[1]->Node()->As()->origin_group; const auto &left_group_aliases_set = memo.GetGroupByID(left_group_id)->GetTableAliases(); const auto &right_group_aliases_set = memo.GetGroupByID(right_group_id)->GetTableAliases(); - auto &predicates = input->Op().As()->predicates; + auto &predicates = input->Node()->As()->predicates; std::vector left_predicates; std::vector right_predicates; std::vector join_predicates; @@ -918,10 +921,10 @@ void PushFilterThroughJoin::Transform( // Construct join operator auto pre_join_predicate = - join_op_expr->Op().As()->join_predicates; + join_op_expr->Node()->As()->join_predicates; join_predicates.insert(join_predicates.end(), pre_join_predicate.begin(), pre_join_predicate.end()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalInnerJoin::make(join_predicates)); @@ -966,20 +969,20 @@ PushFilterThroughAggregation::PushFilterThroughAggregation() { match_pattern->AddChild(child); } -bool PushFilterThroughAggregation::Check(std::shared_ptr, +bool PushFilterThroughAggregation::Check(std::shared_ptr, OptimizeContext *) const { return true; } void PushFilterThroughAggregation::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PushFilterThroughAggregation::Transform"); auto aggregation_op = - input->Children().at(0)->Op().As(); + input->Children().at(0)->Node()->As(); - auto &predicates = input->Op().As()->predicates; + auto &predicates = input->Node()->As()->predicates; std::vector embedded_predicates; std::vector pushdown_predicates; @@ -1000,7 +1003,7 @@ void PushFilterThroughAggregation::Transform( embedded_predicates.emplace_back(predicate); } auto groupby_cols = aggregation_op->columns; - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalAggregateAndGroupBy::make(groupby_cols, embedded_predicates)); @@ -1030,7 +1033,7 @@ CombineConsecutiveFilter::CombineConsecutiveFilter() { match_pattern->AddChild(child); } -bool CombineConsecutiveFilter::Check(std::shared_ptr plan, +bool CombineConsecutiveFilter::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1046,18 +1049,18 @@ bool CombineConsecutiveFilter::Check(std::shared_ptr plan, } void CombineConsecutiveFilter::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { auto child_filter = input->Children()[0]; - auto root_predicates = input->Op().As()->predicates; - auto &child_predicates = child_filter->Op().As()->predicates; + auto root_predicates = input->Node()->As()->predicates; + auto &child_predicates = child_filter->Node()->As()->predicates; root_predicates.insert(root_predicates.end(), child_predicates.begin(), child_predicates.end()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalFilter::make(root_predicates)); @@ -1077,7 +1080,7 @@ EmbedFilterIntoGet::EmbedFilterIntoGet() { match_pattern->AddChild(child); } -bool EmbedFilterIntoGet::Check(std::shared_ptr plan, +bool EmbedFilterIntoGet::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1085,14 +1088,14 @@ bool EmbedFilterIntoGet::Check(std::shared_ptr plan, } void EmbedFilterIntoGet::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { - auto get = input->Children()[0]->Op().As(); + auto get = input->Children()[0]->Node()->As(); - auto predicates = input->Op().As()->predicates; + auto predicates = input->Node()->As()->predicates; - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalGet::make(get->get_id, predicates, get->table, get->table_alias, get->is_for_update)); @@ -1113,15 +1116,15 @@ MarkJoinToInnerJoin::MarkJoinToInnerJoin() { int MarkJoinToInnerJoin::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } return static_cast(UnnestPromise::Low); } -bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, +bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1133,16 +1136,16 @@ bool MarkJoinToInnerJoin::Check(std::shared_ptr plan, } void MarkJoinToInnerJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("MarkJoinToInnerJoin::Transform"); - UNUSED_ATTRIBUTE auto mark_join = input->Op().As(); + UNUSED_ATTRIBUTE auto mark_join = input->Node()->As(); auto &join_children = input->Children(); PELOTON_ASSERT(mark_join->join_predicates.empty()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared(LogicalInnerJoin::make()); output->PushChild(join_children[0]); @@ -1164,15 +1167,15 @@ SingleJoinToInnerJoin::SingleJoinToInnerJoin() { int SingleJoinToInnerJoin::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } return static_cast(UnnestPromise::Low); } -bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, +bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1184,16 +1187,16 @@ bool SingleJoinToInnerJoin::Check(std::shared_ptr plan, } void SingleJoinToInnerJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("SingleJoinToInnerJoin::Transform"); - UNUSED_ATTRIBUTE auto single_join = input->Op().As(); + UNUSED_ATTRIBUTE auto single_join = input->Node()->As(); auto &join_children = input->Children(); PELOTON_ASSERT(single_join->join_predicates.empty()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared(LogicalInnerJoin::make()); output->PushChild(join_children[0]); @@ -1217,15 +1220,15 @@ PullFilterThroughMarkJoin::PullFilterThroughMarkJoin() { int PullFilterThroughMarkJoin::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } return static_cast(UnnestPromise::High); } -bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, +bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1239,22 +1242,22 @@ bool PullFilterThroughMarkJoin::Check(std::shared_ptr plan, } void PullFilterThroughMarkJoin::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PullFilterThroughMarkJoin::Transform"); - UNUSED_ATTRIBUTE auto mark_join = input->Op().As(); + UNUSED_ATTRIBUTE auto mark_join = input->Node()->As(); auto &join_children = input->Children(); - auto filter = join_children[1]->Op(); + auto filter = join_children[1]->Node(); auto &filter_children = join_children[1]->Children(); PELOTON_ASSERT(mark_join->join_predicates.empty()); - std::shared_ptr output = + std::shared_ptr output = std::make_shared(filter); - std::shared_ptr join = - std::make_shared(input->Op()); + std::shared_ptr join = + std::make_shared(input->Node()); join->PushChild(join_children[0]); join->PushChild(filter_children[0]); @@ -1278,16 +1281,16 @@ PullFilterThroughAggregation::PullFilterThroughAggregation() { int PullFilterThroughAggregation::Promise(GroupExpression *group_expr, OptimizeContext *context) const { (void)context; - auto root_type = match_pattern->Type(); + auto root_type = match_pattern->GetOpType(); // This rule is not applicable - if (root_type != OpType::Leaf && root_type != group_expr->Op().GetType()) { + if (root_type != OpType::Leaf && root_type != group_expr->Node()->GetOpType()) { return 0; } return static_cast(UnnestPromise::High); } bool PullFilterThroughAggregation::Check( - std::shared_ptr plan, OptimizeContext *context) const { + std::shared_ptr plan, OptimizeContext *context) const { (void)context; (void)plan; @@ -1300,18 +1303,18 @@ bool PullFilterThroughAggregation::Check( } void PullFilterThroughAggregation::Transform( - std::shared_ptr input, - std::vector> &transformed, + std::shared_ptr input, + std::vector> &transformed, UNUSED_ATTRIBUTE OptimizeContext *context) const { LOG_TRACE("PullFilterThroughAggregation::Transform"); auto &memo = context->metadata->memo; auto &filter_expr = input->Children()[0]; auto child_group_id = - filter_expr->Children()[0]->Op().As()->origin_group; + filter_expr->Children()[0]->Node()->As()->origin_group; const auto &child_group_aliases_set = memo.GetGroupByID(child_group_id)->GetTableAliases(); - auto &predicates = filter_expr->Op().As()->predicates; + auto &predicates = filter_expr->Node()->As()->predicates; std::vector correlated_predicates; std::vector normal_predicates; @@ -1337,15 +1340,15 @@ void PullFilterThroughAggregation::Transform( return; } - auto aggregation = input->Op().As(); + auto aggregation = input->Node()->As(); for (auto &col : aggregation->columns) { new_groupby_cols.emplace_back(col->Copy()); } - std::shared_ptr output = + std::shared_ptr output = std::make_shared( LogicalFilter::make(correlated_predicates)); std::vector new_having(aggregation->having); - std::shared_ptr new_aggregation = + std::shared_ptr new_aggregation = std::make_shared( LogicalAggregateAndGroupBy::make(new_groupby_cols, new_having)); output->PushChild(new_aggregation); @@ -1353,7 +1356,7 @@ void PullFilterThroughAggregation::Transform( // Construct child filter if any if (!normal_predicates.empty()) { - std::shared_ptr new_filter = + std::shared_ptr new_filter = std::make_shared( LogicalFilter::make(normal_predicates)); new_aggregation->PushChild(new_filter); diff --git a/src/optimizer/rule_rewrite.cpp b/src/optimizer/rule_rewrite.cpp new file mode 100644 index 00000000000..c1439a847e3 --- /dev/null +++ b/src/optimizer/rule_rewrite.cpp @@ -0,0 +1,626 @@ +#include + +#include "catalog/column_catalog.h" +#include "catalog/index_catalog.h" +#include "catalog/table_catalog.h" +#include "expression/group_marker_expression.h" +#include "optimizer/operators.h" +#include "optimizer/absexpr_expression.h" +#include "optimizer/optimizer_metadata.h" +#include "optimizer/properties.h" +#include "optimizer/rule_rewrite.h" +#include "optimizer/util.h" +#include "type/value_factory.h" + +namespace peloton { +namespace optimizer { + +// =========================================================== +// +// ComparatorElimination related functions +// +// =========================================================== +ComparatorElimination::ComparatorElimination(RuleType rule, ExpressionType root) { + type_ = rule; + + auto left = std::make_shared(ExpressionType::VALUE_CONSTANT); + auto right = std::make_shared(ExpressionType::VALUE_CONSTANT); + match_pattern = std::make_shared(root); + match_pattern->AddChild(left); + match_pattern->AddChild(right); +} + +int ComparatorElimination::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(MEDIUM_PRIORITY); +} + +bool ComparatorElimination::Check(std::shared_ptr plan, + OptimizeContext *context) const { + (void)context; + (void)plan; + return true; +} + +void ComparatorElimination::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)transformed; + (void)context; + + // Extract the AbstractExpression through indirection layer. + // Since the binding succeeded, there are guaranteed to be two children. + PELOTON_ASSERT(input->Children().size() == 2); + + auto left_abs = std::dynamic_pointer_cast(input->Children()[0]->Node()); + auto right_abs = std::dynamic_pointer_cast(input->Children()[1]->Node()); + PELOTON_ASSERT(left_abs != nullptr && right_abs != nullptr); + + auto left = left_abs->GetExpr(); + auto right = right_abs->GetExpr(); + auto lv = std::dynamic_pointer_cast(left); + auto rv = std::dynamic_pointer_cast(right); + PELOTON_ASSERT(lv != nullptr && rv != nullptr); + + // Get the Value from ConstantValueExpression + auto lvalue = lv->GetValue(); + auto rvalue = rv->GetValue(); + if (lvalue.CheckComparable(rvalue)) { + CmpBool cmp = CmpBool::CmpTrue; + switch (type_) { + case RuleType::CONSTANT_COMPARE_EQUAL: + cmp = lvalue.CompareEquals(rvalue); + break; + case RuleType::CONSTANT_COMPARE_NOTEQUAL: + cmp = lvalue.CompareNotEquals(rvalue); + break; + case RuleType::CONSTANT_COMPARE_LESSTHAN: + cmp = lvalue.CompareLessThan(rvalue); + break; + case RuleType::CONSTANT_COMPARE_GREATERTHAN: + cmp = lvalue.CompareGreaterThan(rvalue); + break; + case RuleType::CONSTANT_COMPARE_LESSTHANOREQUALTO: + // lv <= rv does not have a predefined function so we do the following: + // (1) Compute truth value of lvalue > rvalue + // (2) Flip the truth value unless CmpBool::NULL_ + cmp = lvalue.CompareGreaterThan(rvalue); + if (cmp != CmpBool::NULL_) { + cmp = (cmp == CmpBool::CmpFalse) ? CmpBool::CmpTrue : CmpBool::CmpFalse; + } + break; + case RuleType::CONSTANT_COMPARE_GREATERTHANOREQUALTO: + cmp = lvalue.CompareGreaterThanEquals(rvalue); + break; + default: + // Other type_ should not be handled by this rule + int type = static_cast(type_); + LOG_ERROR("lvalue compare rvalue with RuleType (%d) not implemented", type); + PELOTON_ASSERT(0); + break; + } + + // Create the replacement + type::Value val = type::ValueFactory::GetBooleanValue(cmp); + auto expr = std::make_shared(val); + auto container = std::make_shared(AbsExprNode(expr)); + auto shared = std::make_shared(container); + transformed.push_back(shared); + } + + // If the values cannot be comparable, we leave them as is. + // We don't throw an error or anything because it is possible this branch + // may be collapsed due to subsequent optimizations, and it is likely + // any error will be caught during actual query execution. + return; +} + +// =========================================================== +// +// EquivalentTransform related functions +// +// =========================================================== +EquivalentTransform::EquivalentTransform(RuleType rule, ExpressionType root) { + type_ = rule; + + auto left = std::make_shared(ExpressionType::GROUP_MARKER); + auto right = std::make_shared(ExpressionType::GROUP_MARKER); + match_pattern = std::make_shared(root); + match_pattern->AddChild(left); + match_pattern->AddChild(right); +} + +int EquivalentTransform::Promise(GroupExpression *group_expr, + OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(HIGH_PRIORITY); +} + +bool EquivalentTransform::Check(std::shared_ptr plan, + OptimizeContext *context) const { + (void)context; + (void)plan; + return true; +} + +void EquivalentTransform::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)transformed; + (void)context; + + // We expect EquivalentTransform to operate on AND / OR which should + // have exactly 2 children for the expression to logically make sense. + PELOTON_ASSERT(input->Children().size() == 2); + + // Create flipped ordering + auto left = input->Children()[0]; + auto right = input->Children()[1]; + + // The children do not strictly matter anymore + auto type = match_pattern->GetExpType(); + auto expr = std::make_shared(type); + auto a_expr = std::make_shared(expr); + auto shared = std::make_shared(a_expr); + + // Create flipped ordering at logical level + shared->PushChild(right); + shared->PushChild(left); + transformed.push_back(shared); + return; +} + + +// =========================================================== +// +// Transitive-Transform related functions +// +// =========================================================== +TVEqualityWithTwoCVTransform::TVEqualityWithTwoCVTransform() { + type_ = RuleType::TV_EQUALITY_WITH_TWO_CV; + + // (A.B = x) AND (A.B = y) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_AND); + + auto l_eq = std::make_shared(ExpressionType::COMPARE_EQUAL); + auto l_left = std::make_shared(ExpressionType::VALUE_TUPLE); + auto l_right = std::make_shared(ExpressionType::VALUE_CONSTANT); + l_eq->AddChild(l_left); + l_eq->AddChild(l_right); + + auto r_eq = std::make_shared(ExpressionType::COMPARE_EQUAL); + auto r_left = std::make_shared(ExpressionType::VALUE_TUPLE); + auto r_right = std::make_shared(ExpressionType::VALUE_CONSTANT); + r_eq->AddChild(r_left); + r_eq->AddChild(r_right); + + match_pattern->AddChild(l_eq); + match_pattern->AddChild(r_eq); +} + +int TVEqualityWithTwoCVTransform::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(LOW_PRIORITY); +} + +bool TVEqualityWithTwoCVTransform::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void TVEqualityWithTwoCVTransform::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + //TODO(wz2): TVEqualityWithTwoCVTransform should work beyond straight equality + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (A.B = x) AND (A.B = y) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_AND); + + auto l_eq = input->Children()[0]; + auto r_eq = input->Children()[1]; + PELOTON_ASSERT(l_eq->Children().size() == 2); + PELOTON_ASSERT(r_eq->Children().size() == 2); + PELOTON_ASSERT(l_eq->Node()->GetExpType() == ExpressionType::COMPARE_EQUAL); + PELOTON_ASSERT(r_eq->Node()->GetExpType() == ExpressionType::COMPARE_EQUAL); + + auto l_tv = l_eq->Children()[0]; + auto l_cv = l_eq->Children()[1]; + PELOTON_ASSERT(l_tv->Children().size() == 0); + PELOTON_ASSERT(l_cv->Children().size() == 0); + PELOTON_ASSERT(l_tv->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(l_cv->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + auto r_tv = r_eq->Children()[0]; + auto r_cv = r_eq->Children()[1]; + PELOTON_ASSERT(r_tv->Children().size() == 0); + PELOTON_ASSERT(r_cv->Children().size() == 0); + PELOTON_ASSERT(r_tv->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(r_cv->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + auto l_tv_c = std::dynamic_pointer_cast(l_tv->Node()); + auto r_tv_c = std::dynamic_pointer_cast(r_tv->Node()); + auto l_cv_c = std::dynamic_pointer_cast(l_cv->Node()); + auto r_cv_c = std::dynamic_pointer_cast(r_cv->Node()); + PELOTON_ASSERT(l_tv_c != nullptr && r_tv_c != nullptr); + PELOTON_ASSERT(l_cv_c != nullptr && r_cv_c != nullptr); + + auto l_tv_expr = l_tv_c->GetExpr(); + auto r_tv_expr = r_tv_c->GetExpr(); + if (l_tv_expr->ExactlyEquals(*r_tv_expr)) { + // Given the pattern (A.B = x) AND (C.D = y), the IF statement asserts that A.B is the same as C.D + // TODO(wz2): ExactlyEquals may be too strict, since must match bound_oid, table_name, col_name + + // ExactlyEqual on TupleValueExpression has sufficient check + auto l_cv_expr = std::dynamic_pointer_cast(l_cv_c->GetExpr()); + auto r_cv_expr = std::dynamic_pointer_cast(r_cv_c->GetExpr()); + auto l_cv_val = l_cv_expr->GetValue(); + auto r_cv_val = r_cv_expr->GetValue(); + if (l_cv_val.CheckComparable(r_cv_val)) { + // Given a pattern (A.B = x) AND (A.B = y), we perform the following: + // - Rewrite expression to (A.B = x) if x == y + // - Rewrite expression to FALSE if x != y (including if x / y is NULL) + bool is_eq = false; + if (l_cv_val.CompareEquals(r_cv_val) == CmpBool::CmpTrue) { + // This means in the pattern (A.B = x) AND (A.B = y) that x == y + is_eq = true; + } + + if (is_eq) { + transformed.push_back(l_eq); + } else { + auto val = type::ValueFactory::GetBooleanValue(false); + auto constant = std::make_shared(val); + auto abs_expr = std::make_shared(std::make_shared(AbsExprNode(constant))); + transformed.push_back(abs_expr); + } + + return; + } + } + + return; +} + +TransitiveClosureConstantTransform::TransitiveClosureConstantTransform() { + type_ = RuleType::TRANSITIVE_CLOSURE_CONSTANT; + + // (A.B = x) AND (A.B = C.D) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_AND); + + auto l_eq = std::make_shared(ExpressionType::COMPARE_EQUAL); + auto l_left = std::make_shared(ExpressionType::VALUE_TUPLE); + auto l_right = std::make_shared(ExpressionType::VALUE_CONSTANT); + l_eq->AddChild(l_left); + l_eq->AddChild(l_right); + + auto r_eq = std::make_shared(ExpressionType::COMPARE_EQUAL); + auto r_left = std::make_shared(ExpressionType::VALUE_TUPLE); + auto r_right = std::make_shared(ExpressionType::VALUE_TUPLE); + r_eq->AddChild(r_left); + r_eq->AddChild(r_right); + + match_pattern->AddChild(l_eq); + match_pattern->AddChild(r_eq); +} + +int TransitiveClosureConstantTransform::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(LOW_PRIORITY); +} + +bool TransitiveClosureConstantTransform::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void TransitiveClosureConstantTransform::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + //TODO(wz2): TransitiveClosureConstant should work beyond straight equality + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (A.B = x) AND (A.B = C.D) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_AND); + + auto l_eq = input->Children()[0]; + auto r_eq = input->Children()[1]; + PELOTON_ASSERT(l_eq->Children().size() == 2); + PELOTON_ASSERT(r_eq->Children().size() == 2); + PELOTON_ASSERT(l_eq->Node()->GetExpType() == ExpressionType::COMPARE_EQUAL); + PELOTON_ASSERT(r_eq->Node()->GetExpType() == ExpressionType::COMPARE_EQUAL); + + auto l_tv = l_eq->Children()[0]; + auto l_cv = l_eq->Children()[1]; + PELOTON_ASSERT(l_tv->Children().size() == 0); + PELOTON_ASSERT(l_cv->Children().size() == 0); + PELOTON_ASSERT(l_tv->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(l_cv->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + auto r_tv_l = r_eq->Children()[0]; + auto r_tv_r = r_eq->Children()[1]; + PELOTON_ASSERT(r_tv_l->Children().size() == 0); + PELOTON_ASSERT(r_tv_r->Children().size() == 0); + PELOTON_ASSERT(r_tv_l->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + PELOTON_ASSERT(r_tv_r->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + + auto l_tv_c = std::dynamic_pointer_cast(l_tv->Node()); + auto r_tv_l_c = std::dynamic_pointer_cast(r_tv_l->Node()); + auto r_tv_r_c = std::dynamic_pointer_cast(r_tv_r->Node()); + PELOTON_ASSERT(l_tv_c != nullptr && r_tv_l_c != nullptr && r_tv_r_c != nullptr); + + auto l_tv_expr = l_tv_c->GetExpr(); + auto r_tv_l_expr = r_tv_l_c->GetExpr(); + auto r_tv_r_expr = r_tv_r_c->GetExpr(); + + // At this stage, we have the arbitrary structure: (A.B = x) AND (C.D = E.F) + // TODO(wz2): ExactlyEquals for TupleValue may be too strict, since must match bound_oid, table_name, col_name + if (r_tv_l_expr->ExactlyEquals(*r_tv_r_expr)) { + // Handles case where C.D = E.F, which can rewrite to just A.B = x + transformed.push_back(l_eq); + return; + } + + if (!l_tv_expr->ExactlyEquals(*r_tv_l_expr) && !l_tv_expr->ExactlyEquals(*r_tv_r_expr)) { + // We know that A.B != C.D and A.B != E.F, so no optimization possible + return; + } + + auto new_left_eq = l_eq; + auto right_val_copy = l_cv; + auto new_right_eq = std::make_shared(r_eq->Node()); + + // At this stage, we have knowledge that C.D != E.F + if (l_tv_expr->ExactlyEquals(*r_tv_l_expr)) { + // At this stage, we have knowledge that A.B = C.D + new_right_eq->PushChild(right_val_copy); + new_right_eq->PushChild(r_tv_r); + } else { + // At this stage, we have knowledge that A.B = E.F + new_right_eq->PushChild(r_tv_l); + new_right_eq->PushChild(right_val_copy); + } + + // Create new root expression + auto abs_expr = std::make_shared(input->Node()); + abs_expr->PushChild(new_left_eq); + abs_expr->PushChild(new_right_eq); + transformed.push_back(abs_expr); +} + +// =========================================================== +// +// Boolean short-circuit related functions +// +// =========================================================== +AndShortCircuit::AndShortCircuit() { + type_ = RuleType::AND_SHORT_CIRCUIT; + + // (FALSE AND ) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_AND); + auto left_child = std::make_shared(ExpressionType::VALUE_CONSTANT); + auto right_child = std::make_shared(ExpressionType::GROUP_MARKER); + + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); +} + +int AndShortCircuit::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(HIGH_PRIORITY); +} + +bool AndShortCircuit::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void AndShortCircuit::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (FALSE AND ) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_AND); + + std::shared_ptr left = input->Children()[0]; + PELOTON_ASSERT(left->Children().size() == 0); + PELOTON_ASSERT(left->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); + PELOTON_ASSERT(left_c != nullptr); + + auto left_cv_expr = std::dynamic_pointer_cast(left_c->GetExpr()); + type::Value left_value = left_cv_expr->GetValue(); + + LOG_DEBUG("fjdsklafjksdjflkadsjf"); + + // Only transform the expression if we're ANDing a FALSE boolean value + if (left_value.GetTypeId() == type::TypeId::BOOLEAN && left_value.IsFalse()) { + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + std::shared_ptr false_expr = std::make_shared(val_false); + std::shared_ptr false_cnt = std::make_shared(false_expr); + std::shared_ptr false_container = std::make_shared(false_cnt); + transformed.push_back(false_container); + } +} + + +OrShortCircuit::OrShortCircuit() { + type_ = RuleType::OR_SHORT_CIRCUIT; + + // (FALSE AND ) + match_pattern = std::make_shared(ExpressionType::CONJUNCTION_OR); + auto left_child = std::make_shared(ExpressionType::VALUE_CONSTANT); + auto right_child = std::make_shared(ExpressionType::GROUP_MARKER); + + match_pattern->AddChild(left_child); + match_pattern->AddChild(right_child); +} + +int OrShortCircuit::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(HIGH_PRIORITY); +} + +bool OrShortCircuit::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void OrShortCircuit::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (TRUE OR ) + PELOTON_ASSERT(input->Children().size() == 2); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::CONJUNCTION_OR); + + std::shared_ptr left = input->Children()[0]; + PELOTON_ASSERT(left->Children().size() == 0); + PELOTON_ASSERT(left->Node()->GetExpType() == ExpressionType::VALUE_CONSTANT); + + std::shared_ptr left_c = std::dynamic_pointer_cast(left->Node()); + PELOTON_ASSERT(left_c != nullptr); + + auto left_cv_expr = std::dynamic_pointer_cast(left_c->GetExpr()); + type::Value left_value = left_cv_expr->GetValue(); + + // Only transform the expression if we're ANDing a TRUE boolean value + if (left_value.GetTypeId() == type::TypeId::BOOLEAN && left_value.IsTrue()) { + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + std::shared_ptr true_expr = std::make_shared(val_true); + std::shared_ptr true_cnt = std::make_shared(true_expr); + std::shared_ptr true_container = std::make_shared(true_cnt); + transformed.push_back(true_container); + } +} + + +NullLookupOnNotNullColumn::NullLookupOnNotNullColumn() { + type_ = RuleType::NULL_LOOKUP_ON_NOT_NULL_COLUMN; + + // Structure: [T.X IS NULL] + match_pattern = std::make_shared(ExpressionType::OPERATOR_IS_NULL); + auto child = std::make_shared(ExpressionType::VALUE_TUPLE); + + match_pattern->AddChild(child); +} + +int NullLookupOnNotNullColumn::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(LOW_PRIORITY); +} + +bool NullLookupOnNotNullColumn::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void NullLookupOnNotNullColumn::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (TRUE OR ) + PELOTON_ASSERT(input->Children().size() == 1); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::OPERATOR_IS_NULL); + + std::shared_ptr child = input->Children()[0]; + PELOTON_ASSERT(child->Children().size() == 0); + PELOTON_ASSERT(child->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + + std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); + PELOTON_ASSERT(child_c != nullptr); + + auto tuple_expr = std::dynamic_pointer_cast(child_c->GetExpr()); + + // Only transform into [FALSE] if the tuple value expression is specifically non-NULL, + // otherwise do nothing + if (tuple_expr->GetIsNotNull()) { + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + std::shared_ptr false_expr = std::make_shared(val_false); + std::shared_ptr false_cnt = std::make_shared(false_expr); + std::shared_ptr false_container = std::make_shared(false_cnt); + transformed.push_back(false_container); + } +} + +NotNullLookupOnNotNullColumn::NotNullLookupOnNotNullColumn() { + type_ = RuleType::NOT_NULL_LOOKUP_ON_NOT_NULL_COLUMN; + + // Structure: [T.X IS NOT NULL] + match_pattern = std::make_shared(ExpressionType::OPERATOR_IS_NOT_NULL); + auto child = std::make_shared(ExpressionType::VALUE_TUPLE); + + match_pattern->AddChild(child); +} + +int NotNullLookupOnNotNullColumn::Promise(GroupExpression *group_expr, OptimizeContext *context) const { + (void)group_expr; + (void)context; + return static_cast(LOW_PRIORITY); +} + +bool NotNullLookupOnNotNullColumn::Check(std::shared_ptr plan, OptimizeContext *context) const { + (void)plan; + (void)context; + return true; +} + +void NotNullLookupOnNotNullColumn::Transform(std::shared_ptr input, + std::vector> &transformed, + OptimizeContext *context) const { + (void)context; + (void)transformed; + + // Asserting guarantees provided by the GroupExprBindingIterator + // Structure: (TRUE OR ) + PELOTON_ASSERT(input->Children().size() == 1); + PELOTON_ASSERT(input->Node()->GetExpType() == ExpressionType::OPERATOR_IS_NOT_NULL); + + std::shared_ptr child = input->Children()[0]; + PELOTON_ASSERT(child->Children().size() == 0); + PELOTON_ASSERT(child->Node()->GetExpType() == ExpressionType::VALUE_TUPLE); + + std::shared_ptr child_c = std::dynamic_pointer_cast(child->Node()); + auto tuple_expr = std::dynamic_pointer_cast(child_c->GetExpr()); + + // Only transform into [TRUE] if the tuple value expression is specifically non-NULL, + // otherwise do nothing + if (tuple_expr->GetIsNotNull()) { + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + std::shared_ptr true_expr = std::make_shared(val_true); + std::shared_ptr true_cnt = std::make_shared(true_expr); + std::shared_ptr true_container = std::make_shared(true_cnt); + transformed.push_back(true_container); + } +} + +} // namespace optimizer +} // namespace peloton diff --git a/src/optimizer/stats/child_stats_deriver.cpp b/src/optimizer/stats/child_stats_deriver.cpp index d320547915c..2dc50a22e9d 100644 --- a/src/optimizer/stats/child_stats_deriver.cpp +++ b/src/optimizer/stats/child_stats_deriver.cpp @@ -27,7 +27,7 @@ vector ChildStatsDeriver::DeriveInputStats(GroupExpression *gexpr, gexpr_ = gexpr; memo_ = memo; output_ = vector(gexpr->GetChildrenGroupsSize(), ExprSet{}); - gexpr->Op().Accept(this); + gexpr->Node()->Accept(this); return std::move(output_); } diff --git a/src/optimizer/stats/stats_calculator.cpp b/src/optimizer/stats/stats_calculator.cpp index d086938a817..0e4e09a573b 100644 --- a/src/optimizer/stats/stats_calculator.cpp +++ b/src/optimizer/stats/stats_calculator.cpp @@ -33,7 +33,7 @@ void StatsCalculator::CalculateStats(GroupExpression *gexpr, memo_ = memo; required_cols_ = required_cols; txn_ = txn; - gexpr->Op().Accept(this); + gexpr->Node()->Accept(this); } void StatsCalculator::Visit(const LogicalGet *op) { diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index c6f8df81fda..31dab68579f 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -317,6 +317,24 @@ std::shared_ptr TrafficCop::PrepareStatement( tcop_txn_state_.top().first, default_database_name_); bind_node_visitor.BindNameToNode( statement->GetStmtParseTreeList()->GetStatement(0)); + + // Apply the rewriter if possible on the top statement + // TODO(): Rewrite more; move into optimizer and rewrite after unnesting + auto top_stmt = statement->GetStmtParseTreeList()->GetStatement(0); + switch (top_stmt->GetType()) { + case StatementType::SELECT: { + auto select = dynamic_cast(top_stmt); + + expression::AbstractExpression *where = select->where_clause.get(); + auto optimal = rewriter_.RewriteExpression(where); + select->UpdateWhereClause(optimal); + rewriter_.Reset(); + break; + } + default: + break; + } + auto plan = optimizer_->BuildPelotonPlanTree( statement->GetStmtParseTreeList(), tcop_txn_state_.top().first); statement->SetPlanTree(plan); @@ -563,15 +581,15 @@ ResultType TrafficCop::ExecuteStatement( statement, std::move(param_stats)); } - LOG_TRACE("Execute Statement of name: %s", + LOG_DEBUG("Execute Statement of name: %s", statement->GetStatementName().c_str()); - LOG_TRACE("Execute Statement of query: %s", + LOG_DEBUG("Execute Statement of query: %s", statement->GetQueryString().c_str()); - LOG_TRACE("Execute Statement Plan:\n%s", + LOG_DEBUG("Execute Statement Plan:\n%s", planner::PlanUtil::GetInfo(statement->GetPlanTree().get()).c_str()); - LOG_TRACE("Execute Statement Query Type: %s", + LOG_DEBUG("Execute Statement Query Type: %s", statement->GetQueryTypeString().c_str()); - LOG_TRACE("----QueryType: %d--------", + LOG_DEBUG("----QueryType: %d--------", static_cast(statement->GetQueryType())); try { diff --git a/test/optimizer/absexpr_test.cpp b/test/optimizer/absexpr_test.cpp new file mode 100644 index 00000000000..6c8ccdc917e --- /dev/null +++ b/test/optimizer/absexpr_test.cpp @@ -0,0 +1,460 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// absexpr_test.cpp +// +// Identification: test/optimizer/absexpr_test.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include "common/harness.h" + +#include "function/functions.h" +#include "optimizer/operators.h" +#include "optimizer/rewriter.h" +#include "expression/aggregate_expression.h" +#include "expression/conjunction_expression.h" +#include "expression/subquery_expression.h" +#include "expression/parameter_value_expression.h" +#include "expression/star_expression.h" +#include "expression/case_expression.h" +#include "expression/tuple_value_expression.h" +#include "expression/operator_expression.h" +#include "expression/constant_value_expression.h" +#include "expression/comparison_expression.h" +#include "expression/tuple_value_expression.h" +#include "expression/function_expression.h" +#include "type/value_factory.h" +#include "type/value_peeker.h" +#include "optimizer/rule_rewrite.h" +#include "parser/postgresparser.h" + +namespace peloton { + +namespace test { + +using namespace optimizer; + +class AbsExprTest : public PelotonTest { + public: + // Returns expression of (Constant = Tuple Value) + expression::AbstractExpression *GetTVEqualCVExpression(std::string col, int val) { + auto val_e = GetConstantExpression(val); + auto tuple = new expression::TupleValueExpression(std::move(col)); + auto cmp = new expression::ComparisonExpression( + ExpressionType::COMPARE_EQUAL, val_e, tuple + ); + + return cmp; + } + + // Returns ConstantExpression(val) + expression::AbstractExpression *GetConstantExpression(int val) { + auto value = type::ValueFactory::GetIntegerValue(val); + return new expression::ConstantValueExpression(value); + } +}; + +TEST_F(AbsExprTest, CompareTest) { + std::vector compares = { + ExpressionType::COMPARE_EQUAL, + ExpressionType::COMPARE_NOTEQUAL, + ExpressionType::COMPARE_LESSTHAN, + ExpressionType::COMPARE_GREATERTHAN, + ExpressionType::COMPARE_LESSTHANOREQUALTO, + ExpressionType::COMPARE_GREATERTHANOREQUALTO, + ExpressionType::COMPARE_LIKE, + ExpressionType::COMPARE_NOTLIKE, + ExpressionType::COMPARE_IN, + ExpressionType::COMPARE_DISTINCT_FROM + }; + + auto left = new expression::ParameterValueExpression(0); + auto right = new expression::ParameterValueExpression(1); + for (auto type : compares) { + auto cmp_expr = std::make_shared(type, left->Copy(), right->Copy()); + AbsExprNode op = AbsExprNode(cmp_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + + EXPECT_EQ(cmp_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(cmp_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(cmp_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + + EXPECT_EQ(*(cmp_expr->GetChild(0)), *(rebuilt->GetChild(0))); + EXPECT_EQ(*(cmp_expr->GetChild(1)), *(rebuilt->GetChild(1))); + + auto l_child = dynamic_cast(rebuilt->GetChild(0)); + auto r_child = dynamic_cast(rebuilt->GetChild(1)); + EXPECT_TRUE(l_child != nullptr && r_child != nullptr); + EXPECT_TRUE(l_child->GetValueIdx() == 0 && r_child->GetValueIdx() == 1); + + delete rebuilt; + } + + delete left; + delete right; +} + +TEST_F(AbsExprTest, ConjunctionTest) { + std::vector compares = { + ExpressionType::CONJUNCTION_AND, + ExpressionType::CONJUNCTION_OR + }; + + auto tval = type::ValueFactory::GetBooleanValue(true); + auto fval = type::ValueFactory::GetBooleanValue(false); + auto left = new expression::ConstantValueExpression(tval); + auto right = new expression::ConstantValueExpression(fval); + for (auto type : compares) { + auto cmp_expr = std::make_shared(type, left->Copy(), right->Copy()); + AbsExprNode op = AbsExprNode(cmp_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + + EXPECT_EQ(cmp_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(cmp_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(cmp_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + + auto l_child = dynamic_cast(rebuilt->GetChild(0)); + auto r_child = dynamic_cast(rebuilt->GetChild(1)); + EXPECT_TRUE(l_child != nullptr && r_child != nullptr); + EXPECT_TRUE(l_child->ExactlyEquals(*left) && r_child->ExactlyEquals(*right)); + delete rebuilt; + } + + delete left; + delete right; +} + +TEST_F(AbsExprTest, OperatorTest) { + std::vector operators = { + ExpressionType::OPERATOR_PLUS, + ExpressionType::OPERATOR_MINUS, + ExpressionType::OPERATOR_MULTIPLY, + ExpressionType::OPERATOR_DIVIDE, + ExpressionType::OPERATOR_CONCAT, + ExpressionType::OPERATOR_MOD, + }; + + std::vector single_ops = { + ExpressionType::OPERATOR_NOT, + ExpressionType::OPERATOR_IS_NULL, + ExpressionType::OPERATOR_IS_NOT_NULL, + ExpressionType::OPERATOR_EXISTS + }; + + auto left = GetConstantExpression(25); + auto right = GetConstantExpression(30); + for (auto type : operators) { + auto op_expr = std::make_shared(type, type::TypeId::INTEGER, left->Copy(), right->Copy()); + op_expr->DeduceExpressionType(); + + AbsExprNode op = AbsExprNode(op_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy(), right->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + rebuilt->DeduceExpressionType(); + + EXPECT_EQ(op_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(op_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(op_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + + auto l_child = dynamic_cast(rebuilt->GetChild(0)); + auto r_child = dynamic_cast(rebuilt->GetChild(1)); + EXPECT_TRUE(l_child != nullptr && r_child != nullptr); + EXPECT_TRUE(l_child->ExactlyEquals(*(op_expr->GetChild(0)))); + EXPECT_TRUE(r_child->ExactlyEquals(*(op_expr->GetChild(1)))); + EXPECT_TRUE(l_child->ExactlyEquals(*left)); + EXPECT_TRUE(r_child->ExactlyEquals(*right)); + + delete rebuilt; + } + + for (auto type : single_ops) { + auto op_expr = std::make_shared(type, type::TypeId::INTEGER, left->Copy(), nullptr); + op_expr->DeduceExpressionType(); + + AbsExprNode op = AbsExprNode(op_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + rebuilt->DeduceExpressionType(); + + EXPECT_EQ(op_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(op_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(op_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + + auto l_child = dynamic_cast(rebuilt->GetChild(0)); + EXPECT_TRUE(l_child != nullptr); + EXPECT_TRUE(l_child->ExactlyEquals(*(op_expr->GetChild(0)))); + EXPECT_TRUE(l_child->ExactlyEquals(*left)); + + delete rebuilt; + } + + delete left; + delete right; +} + +TEST_F(AbsExprTest, OperatorUnaryMinusTest) { + auto left = GetConstantExpression(25); + auto unary = std::make_shared(left->Copy()); + + AbsExprNode op = AbsExprNode(unary); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({left->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + + EXPECT_EQ(unary->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(unary->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(unary->GetChildrenSize(), rebuilt->GetChildrenSize()); + EXPECT_TRUE(unary->GetChild(0)->ExactlyEquals(*(rebuilt->GetChild(0)))); + EXPECT_TRUE(left->ExactlyEquals(*(rebuilt->GetChild(0)))); + delete rebuilt; + delete left; +} + +TEST_F(AbsExprTest, StarTest) { + auto expr = std::make_shared(); + AbsExprNode op = AbsExprNode(expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + EXPECT_EQ(*expr, *rebuilt); + delete rebuilt; +} + +TEST_F(AbsExprTest, ValueConstantTest) { + auto cv_expr = dynamic_cast(GetConstantExpression(721)); + auto expr = std::shared_ptr(cv_expr); + AbsExprNode op = AbsExprNode(expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + EXPECT_EQ(*expr, *rebuilt); // this does not check value + EXPECT_EQ(expr->GetValueType(), rebuilt->GetValueType()); + + auto lvalue = expr->GetValue(); + + auto rebuilt_val = dynamic_cast(rebuilt); + auto rvalue = rebuilt_val->GetValue(); + EXPECT_TRUE(lvalue.CheckComparable(expr->GetValue())); + EXPECT_TRUE(lvalue.CheckComparable(rvalue)); + EXPECT_TRUE(lvalue.CompareEquals(expr->GetValue()) == CmpBool::CmpTrue); + EXPECT_TRUE(lvalue.CompareEquals(rvalue) == CmpBool::CmpTrue); + delete rebuilt; +} + +TEST_F(AbsExprTest, ValueParameterTest) { + auto expr = std::make_shared(15); + AbsExprNode op = AbsExprNode(expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + EXPECT_EQ(*expr, *rebuilt); // does not check value_idx_ + + auto rebuilt_val = dynamic_cast(rebuilt); + EXPECT_EQ(expr->GetValueIdx(), rebuilt_val->GetValueIdx()); + delete rebuilt; +} + +TEST_F(AbsExprTest, ValueTupleTest) { + auto expr_col = std::make_shared("col"); + expr_col->SetTupleValueExpressionParams(type::TypeId::INTEGER, 1, 1); + expr_col->SetTableName("tbl"); + + AbsExprNode op = AbsExprNode(expr_col); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + EXPECT_EQ(*expr_col, *rebuilt); // checks tbl_name, col_name + + auto rebuilt_val = dynamic_cast(rebuilt); + EXPECT_EQ(rebuilt_val->GetColumnId(), expr_col->GetColumnId()); + EXPECT_EQ(rebuilt_val->GetTableName(), expr_col->GetTableName()); + EXPECT_EQ(rebuilt_val->GetColumnName(), expr_col->GetColumnName()); + EXPECT_EQ(rebuilt_val->GetTupleId(), expr_col->GetTupleId()); + EXPECT_EQ(rebuilt_val->GetValueType(), expr_col->GetValueType()); + delete rebuilt; +} + +TEST_F(AbsExprTest, AggregateNodeTest) { + std::vector aggregates = { + ExpressionType::AGGREGATE_COUNT, + ExpressionType::AGGREGATE_SUM, + ExpressionType::AGGREGATE_MIN, + ExpressionType::AGGREGATE_MAX, + ExpressionType::AGGREGATE_AVG + }; + + // Generic aggregation + for (auto type : aggregates) { + auto child = new expression::TupleValueExpression("col_a"); + auto agg_expr = std::make_shared(type, true, child->Copy()); + agg_expr->DeduceExpressionType(); + + AbsExprNode op = AbsExprNode(agg_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({child->Copy()}); + EXPECT_TRUE(rebuilt != nullptr); + + rebuilt->DeduceExpressionType(); + EXPECT_EQ(agg_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(agg_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(agg_expr->GetChildrenSize(), rebuilt->GetChildrenSize()); + EXPECT_EQ(*(agg_expr->GetChild(0)), *(rebuilt->GetChild(0))); + + EXPECT_TRUE(agg_expr->distinct_); + EXPECT_TRUE(rebuilt->distinct_); + + delete rebuilt; + delete child; + } + + // COUNT (*) Aggregation + auto child = new expression::StarExpression(); + auto agg_expr = std::make_shared(ExpressionType::AGGREGATE_COUNT, true, child); + + agg_expr->DeduceExpressionType(); + EXPECT_TRUE(agg_expr->GetExpressionType() == ExpressionType::AGGREGATE_COUNT_STAR); + + AbsExprNode op = AbsExprNode(agg_expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + rebuilt->DeduceExpressionType(); + + EXPECT_EQ(agg_expr->GetExpressionType(), rebuilt->GetExpressionType()); + EXPECT_EQ(agg_expr->GetValueType(), rebuilt->GetValueType()); + EXPECT_EQ(agg_expr->distinct_, rebuilt->distinct_); + EXPECT_EQ(rebuilt->GetChildrenSize(), 0); + delete rebuilt; +} + +TEST_F(AbsExprTest, CaseExpressionTest) { + auto where1 = expression::CaseExpression::AbsExprPtr(GetTVEqualCVExpression("col_a", 1)); + auto where2 = expression::CaseExpression::AbsExprPtr(GetTVEqualCVExpression("col_b", 2)); + auto where3 = expression::CaseExpression::AbsExprPtr(GetTVEqualCVExpression("col_c", 3)); + auto def_c = expression::CaseExpression::AbsExprPtr(GetConstantExpression(4)); + + auto res1 = expression::CaseExpression::AbsExprPtr(GetConstantExpression(1)); + auto res2 = expression::CaseExpression::AbsExprPtr(GetConstantExpression(2)); + auto res3 = expression::CaseExpression::AbsExprPtr(GetConstantExpression(3)); + std::vector clauses; + clauses.push_back(expression::CaseExpression::WhenClause(std::move(where1), std::move(res1))); + clauses.push_back(expression::CaseExpression::WhenClause(std::move(where2), std::move(res2))); + clauses.push_back(expression::CaseExpression::WhenClause(std::move(where3), std::move(res3))); + + auto expr = std::make_shared(type::TypeId::INTEGER, clauses, std::move(def_c)); + AbsExprNode op = AbsExprNode(expr); + expression::AbstractExpression *rebuilt = op.CopyWithChildren({}); + + // Checks every clause except for ConstantValue values + EXPECT_EQ(*expr, *rebuilt); + + auto rebuilt_c = dynamic_cast(rebuilt); + EXPECT_TRUE(rebuilt_c->GetWhenClauseSize() == 3); + + // Check each when clause + for (int i = 0; i < 3; i++) { + auto res = rebuilt_c->GetWhenClauseResult(i); + auto res_c = dynamic_cast(res); + EXPECT_TRUE(res_c != nullptr && res_c->GetValue().GetTypeId() == type::TypeId::INTEGER); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(res_c->GetValue()) == (i + 1)); + + auto cond = rebuilt_c->GetWhenClauseCond(i); + EXPECT_TRUE(cond->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(cond->GetChildrenSize() == 2); + + auto rval = cond->GetChild(0); + EXPECT_TRUE(rval->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto rval_c = dynamic_cast(rval); + EXPECT_TRUE(rval_c != nullptr && rval_c->GetValue().GetTypeId() == type::TypeId::INTEGER); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(rval_c->GetValue()) == (i + 1)); + } + + // Check default clause + auto def_expr = dynamic_cast(rebuilt_c->GetDefault()); + EXPECT_TRUE(def_expr != nullptr); + EXPECT_TRUE(def_expr->GetValue().GetTypeId() == type::TypeId::INTEGER); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(def_expr->GetValue()) == 4); + delete rebuilt; +} + +TEST_F(AbsExprTest, SubqueryTest) { + std::vector> stmts; + { + auto parser = parser::PostgresParser::GetInstance(); + auto query = "SELECT * from foo"; + std::unique_ptr stmt_list(parser.BuildParseTree(query).release()); + stmts = std::move(stmt_list->statements); + } + + EXPECT_TRUE(stmts.size() == 1); + EXPECT_EQ(stmts[0]->GetType(), StatementType::SELECT); + parser::SelectStatement *sel = dynamic_cast(stmts[0].release()); + EXPECT_TRUE(sel != nullptr); + + auto expr = std::make_shared(); + expr->SetSubSelect(sel); + + AbsExprNode container = AbsExprNode(expr); + expression::AbstractExpression *rebuild = container.CopyWithChildren({}); + + EXPECT_EQ(rebuild->GetExpressionType(), expr->GetExpressionType()); + EXPECT_EQ(rebuild->GetChildrenSize(), expr->GetChildrenSize()); + + auto rebuild_s = dynamic_cast(rebuild); + EXPECT_TRUE(rebuild_s != nullptr); + EXPECT_TRUE(rebuild_s->GetSubSelect() == expr->GetSubSelect()); + delete rebuild; +} + +TEST_F(AbsExprTest, FunctionExpressionTest) { + std::vector child1; + std::vector child2; + std::vector types; + for (int i = 0; i < 10; i++) { + child1.push_back(GetConstantExpression(i)); + child2.push_back(GetConstantExpression(i)); + types.push_back(type::TypeId::INTEGER); + } + + auto func = [](const std::vector &val) { + (void)val; + return type::ValueFactory::GetIntegerValue(5); + }; + function::BuiltInFuncType func_ptr = { static_cast(1), func }; + auto expr = std::make_shared("func", child1); + expr->SetBuiltinFunctionExpressionParameters(func_ptr, type::TypeId::INTEGER, types); + + AbsExprNode container = AbsExprNode(expr); + expression::AbstractExpression* rebuild = container.CopyWithChildren(child2); + + EXPECT_EQ(rebuild->GetExpressionType(), expr->GetExpressionType()); + EXPECT_EQ(rebuild->GetChildrenSize(), expr->GetChildrenSize()); + + auto rebuild_c = dynamic_cast(rebuild); + EXPECT_TRUE(rebuild_c != nullptr); + EXPECT_EQ(rebuild_c->GetFuncName(), expr->GetFuncName()); + EXPECT_EQ(rebuild_c->GetFunc().op_id, expr->GetFunc().op_id); + EXPECT_EQ(rebuild_c->GetFunc().impl, expr->GetFunc().impl); + EXPECT_EQ(rebuild_c->GetArgTypes(), expr->GetArgTypes()); + EXPECT_EQ(rebuild_c->IsUDF(), expr->IsUDF()); + + for (int i = 0; i < 10; i++) { + auto l_child = rebuild_c->GetChild(i); + auto r_child = expr->GetChild(i); + EXPECT_TRUE(l_child != r_child && l_child->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto l_cast = dynamic_cast(l_child); + auto r_cast = dynamic_cast(r_child); + EXPECT_TRUE(l_cast != nullptr && r_cast != nullptr); + EXPECT_TRUE(l_cast->ExactlyEquals(*r_cast)); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(l_cast->GetValue()) == i); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(r_cast->GetValue()) == i); + } + + delete rebuild; +} + +} // namespace test +} // namespace peloton diff --git a/test/optimizer/optimizer_rule_test.cpp b/test/optimizer/optimizer_rule_test.cpp index 23f520596dc..d2b806ef91b 100644 --- a/test/optimizer/optimizer_rule_test.cpp +++ b/test/optimizer/optimizer_rule_test.cpp @@ -61,7 +61,7 @@ TEST_F(OptimizerRuleTests, SimpleCommutativeRuleTest) { EXPECT_TRUE(rule.Check(join, nullptr)); - std::vector> outputs; + std::vector> outputs; rule.Transform(join, outputs, nullptr); EXPECT_EQ(outputs.size(), 1); @@ -139,17 +139,17 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); EXPECT_EQ(right_leaf, parent_join->Children()[1]); EXPECT_EQ(1, - parent_join->Op().As()->join_predicates.size()); + parent_join->Node()->As()->join_predicates.size()); EXPECT_EQ(1, parent_join->Children()[0] - ->Op() - .As() + ->Node() + ->As() ->join_predicates.size()); // Setup rule InnerJoinAssociativity rule; EXPECT_TRUE(rule.Check(parent_join, root_context)); - std::vector> outputs; + std::vector> outputs; rule.Transform(parent_join, outputs, root_context); EXPECT_EQ(1, outputs.size()); @@ -159,8 +159,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest) { EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); - auto parent_join_op = output_join->Op().As(); - auto child_join_op = output_join->Children()[1]->Op().As(); + auto parent_join_op = output_join->Node()->As(); + auto child_join_op = output_join->Children()[1]->Node()->As(); EXPECT_EQ(2, parent_join_op->join_predicates.size()); EXPECT_EQ(0, child_join_op->join_predicates.size()); delete root_context; @@ -234,17 +234,17 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { EXPECT_EQ(middle_leaf, parent_join->Children()[0]->Children()[1]); EXPECT_EQ(right_leaf, parent_join->Children()[1]); EXPECT_EQ(2, - parent_join->Op().As()->join_predicates.size()); + parent_join->Node()->As()->join_predicates.size()); EXPECT_EQ(0, parent_join->Children()[0] - ->Op() - .As() + ->Node() + ->As() ->join_predicates.size()); // Setup rule InnerJoinAssociativity rule; EXPECT_TRUE(rule.Check(parent_join, root_context)); - std::vector> outputs; + std::vector> outputs; rule.Transform(parent_join, outputs, root_context); EXPECT_EQ(1, outputs.size()); @@ -254,8 +254,8 @@ TEST_F(OptimizerRuleTests, SimpleAssociativeRuleTest2) { EXPECT_EQ(middle_leaf, output_join->Children()[1]->Children()[0]); EXPECT_EQ(right_leaf, output_join->Children()[1]->Children()[1]); - auto parent_join_op = output_join->Op().As(); - auto child_join_op = output_join->Children()[1]->Op().As(); + auto parent_join_op = output_join->Node()->As(); + auto child_join_op = output_join->Children()[1]->Node()->As(); EXPECT_EQ(1, parent_join_op->join_predicates.size()); EXPECT_EQ(1, child_join_op->join_predicates.size()); delete root_context; diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index f1ffd6add66..c1247baeed6 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -348,7 +348,8 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { std::vector child_groups = {gexpr->GetGroupID()}; std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); + std::make_shared( + std::make_shared(Operator()), child_groups); std::shared_ptr root_context = std::make_shared(&(optimizer.GetMetadata()), nullptr); @@ -367,28 +368,28 @@ TEST_F(OptimizerTests, PushFilterThroughJoinTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); - auto join_op = group_expr->Op().As(); + EXPECT_EQ(OpType::InnerJoin, group_expr->Node()->GetOpType()); + auto join_op = group_expr->Node()->As(); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); // Check left get auto l_group_expr = GetSingleGroupExpression(memo, group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op().GetType()); - auto get_op = l_group_expr->Op().As(); + EXPECT_EQ(OpType::Get, l_group_expr->Node()->GetOpType()); + auto get_op = l_group_expr->Node()->As(); EXPECT_TRUE(get_op->predicates.empty()); // Check right filter auto r_group_expr = GetSingleGroupExpression(memo, group_expr, 1); - EXPECT_EQ(OpType::LogicalFilter, r_group_expr->Op().GetType()); - auto filter_op = r_group_expr->Op().As(); + EXPECT_EQ(OpType::LogicalFilter, r_group_expr->Node()->GetOpType()); + auto filter_op = r_group_expr->Node()->As(); EXPECT_EQ(1, filter_op->predicates.size()); EXPECT_TRUE(filter_op->predicates[0].expr->ExactlyEquals(*predicates[1])); // Check get below filter group_expr = GetSingleGroupExpression(memo, r_group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op().GetType()); - get_op = group_expr->Op().As(); + EXPECT_EQ(OpType::Get, l_group_expr->Node()->GetOpType()); + get_op = group_expr->Node()->As(); EXPECT_TRUE(get_op->predicates.empty()); txn_manager.CommitTransaction(txn); @@ -435,7 +436,8 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { std::vector child_groups = {gexpr->GetGroupID()}; std::shared_ptr head_gexpr = - std::make_shared(Operator(), child_groups); + std::make_shared( + std::make_shared(Operator()), child_groups); std::shared_ptr root_context = std::make_shared(&(optimizer.GetMetadata()), nullptr); @@ -454,21 +456,21 @@ TEST_F(OptimizerTests, PredicatePushDownRewriteTest) { // Check join in the root auto group_expr = GetSingleGroupExpression(memo, head_gexpr.get(), 0); - EXPECT_EQ(OpType::InnerJoin, group_expr->Op().GetType()); - auto join_op = group_expr->Op().As(); + EXPECT_EQ(OpType::InnerJoin, group_expr->Node()->GetOpType()); + auto join_op = group_expr->Node()->As(); EXPECT_EQ(1, join_op->join_predicates.size()); EXPECT_TRUE(join_op->join_predicates[0].expr->ExactlyEquals(*predicates[0])); // Check left get auto l_group_expr = GetSingleGroupExpression(memo, group_expr, 0); - EXPECT_EQ(OpType::Get, l_group_expr->Op().GetType()); - auto get_op = l_group_expr->Op().As(); + EXPECT_EQ(OpType::Get, l_group_expr->Node()->GetOpType()); + auto get_op = l_group_expr->Node()->As(); EXPECT_TRUE(get_op->predicates.empty()); // Check right filter auto r_group_expr = GetSingleGroupExpression(memo, group_expr, 1); - EXPECT_EQ(OpType::Get, r_group_expr->Op().GetType()); - get_op = r_group_expr->Op().As(); + EXPECT_EQ(OpType::Get, r_group_expr->Node()->GetOpType()); + get_op = r_group_expr->Node()->As(); EXPECT_EQ(1, get_op->predicates.size()); EXPECT_TRUE(get_op->predicates[0].expr->ExactlyEquals(*predicates[1])); diff --git a/test/optimizer/rewriter_test.cpp b/test/optimizer/rewriter_test.cpp new file mode 100644 index 00000000000..ac9ea9f9d7c --- /dev/null +++ b/test/optimizer/rewriter_test.cpp @@ -0,0 +1,427 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// operator_test.cpp +// +// Identification: test/optimizer/operator_test.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include "common/harness.h" + +#include "optimizer/operators.h" +#include "optimizer/rewriter.h" +#include "expression/constant_value_expression.h" +#include "expression/comparison_expression.h" +#include "expression/tuple_value_expression.h" +#include "expression/operator_expression.h" +#include "type/value_factory.h" +#include "type/value_peeker.h" +#include "optimizer/rule_rewrite.h" + +namespace peloton { + +namespace test { + +using namespace optimizer; + +class RewriterTests : public PelotonTest {}; + +TEST_F(RewriterTests, SingleCompareEqualRewritePassFalse) { + // 3 = 2 ==> FALSE + type::Value leftValue = type::ValueFactory::GetIntegerValue(3); + type::Value rightValue = type::ValueFactory::GetIntegerValue(2); + auto left = new expression::ConstantValueExpression(leftValue); + auto right = new expression::ConstantValueExpression(rightValue); + auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(common); + + delete rewriter; + delete common; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + +TEST_F(RewriterTests, SingleCompareEqualRewritePassTrue) { + // 4 = 4 ==> TRUE + type::Value leftValue = type::ValueFactory::GetIntegerValue(4); + type::Value rightValue = type::ValueFactory::GetIntegerValue(4); + auto left = new expression::ConstantValueExpression(leftValue); + auto right = new expression::ConstantValueExpression(rightValue); + auto common = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(common); + + delete rewriter; + delete common; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + delete rewrote; +} + +TEST_F(RewriterTests, SimpleEqualityTree) { + // [=] + // [=] [=] ==> FALSE + // [4] [5] [3] [3] + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val5); + auto rb_left_child = new expression::ConstantValueExpression(val3); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + +TEST_F(RewriterTests, ComparativeOperatorTest) { + // [=] + // [<=] [>=] ==> TRUE + // [4] [4] [5] [3] + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val4); + auto rb_left_child = new expression::ConstantValueExpression(val5); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, + rb_left_child, rb_right_child); + auto top = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + delete rewrote; +} + +TEST_F(RewriterTests, BasicAndShortCircuitTest) { + + // First, build the rewriter and the values that will be used in test cases + Rewriter *rewriter = new Rewriter(); + + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + // + // [AND] + // [FALSE] [=] + // [X] [3] + // + // Intended output: [FALSE] + // + + expression::ConstantValueExpression *lh = new expression::ConstantValueExpression(val_false); + expression::ConstantValueExpression *rh_right_child = new expression::ConstantValueExpression(val3); + expression::TupleValueExpression *rh_left_child = new expression::TupleValueExpression("t","x"); + + expression::ComparisonExpression *rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + expression::ConjunctionExpression *root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lh, rh); + + expression::AbstractExpression *rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + delete rewrote; + delete root; + + // + // [AND] + // [TRUE] [=] + // [X] [3] + // + // Intended output: same as input + // + + lh = new expression::ConstantValueExpression(val_true); + rh_right_child = new expression::ConstantValueExpression(val3); + rh_left_child = new expression::TupleValueExpression("t","x"); + + rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, lh, rh); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 2); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::CONJUNCTION_AND); + + delete rewrote; + delete root; + + delete rewriter; +} + + +TEST_F(RewriterTests, BasicOrShortCircuitTest) { + // First, build the rewriter and the values that will be used in test cases + Rewriter *rewriter = new Rewriter(); + + type::Value val_false = type::ValueFactory::GetBooleanValue(false); + type::Value val_true = type::ValueFactory::GetBooleanValue(true); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + // + // [OR] + // [TRUE] [=] + // [X] [3] + // + // Intended output: [TRUE] + // + + expression::ConstantValueExpression *lh = new expression::ConstantValueExpression(val_true); + expression::ConstantValueExpression *rh_right_child = new expression::ConstantValueExpression(val3); + expression::TupleValueExpression *rh_left_child = new expression::TupleValueExpression("t","x"); + + expression::ComparisonExpression *rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + expression::ConjunctionExpression *root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lh, rh); + + expression::AbstractExpression *rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + delete rewrote; + delete root; + + // + // [OR] + // [FALSE] [=] + // [X] [3] + // + // Intended output: same as input + // + + lh = new expression::ConstantValueExpression(val_false); + rh_right_child = new expression::ConstantValueExpression(val3); + rh_left_child = new expression::TupleValueExpression("t","x"); + + rh = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, rh_left_child, rh_right_child); + root = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lh, rh); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 2); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::CONJUNCTION_OR); + + delete rewrote; + delete root; + + delete rewriter; +} + + +TEST_F(RewriterTests, AndShortCircuitComparatorEliminationMixTest) { + // [AND] + // [<=] [=] + // [4] [4] [5] [3] + // Intended Output: FALSE + // + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val4); + auto rb_left_child = new expression::ConstantValueExpression(val5); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + delete rewrote; +} + + +TEST_F(RewriterTests, OrShortCircuitComparatorEliminationMixTest) { + // [OR] + // [<=] [=] + // [4] [4] [5] [3] + // Intended Output: TRUE + // + type::Value val4 = type::ValueFactory::GetIntegerValue(4); + type::Value val5 = type::ValueFactory::GetIntegerValue(5); + type::Value val3 = type::ValueFactory::GetIntegerValue(3); + + auto lb_left_child = new expression::ConstantValueExpression(val4); + auto lb_right_child = new expression::ConstantValueExpression(val4); + auto rb_left_child = new expression::ConstantValueExpression(val5); + auto rb_right_child = new expression::ConstantValueExpression(val3); + + auto lb = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, + lb_left_child, lb_right_child); + auto rb = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, + rb_left_child, rb_right_child); + auto top = new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_OR, lb, rb); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(top); + + delete rewriter; + delete top; + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_TRUE(rewrote->GetChildrenSize() == 0); + EXPECT_TRUE(rewrote->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted->GetValueType() == type::TypeId::BOOLEAN); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + delete rewrote; +} + + +TEST_F(RewriterTests, NotNullColumnsTest) { + + // First, build rewriter to be used in all test cases + Rewriter *rewriter = new Rewriter(); + + // [T.X IS NULL], where X is a non-NULL column in table T + // Intended output: FALSE + + auto child = new expression::TupleValueExpression("t","x"); + child->SetIsNotNull(true); + auto root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NULL, type::TypeId::BOOLEAN, child, nullptr); + + auto rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + auto casted = dynamic_cast(rewrote); + EXPECT_EQ(casted->GetValueType(), type::TypeId::BOOLEAN); + EXPECT_EQ(type::ValuePeeker::PeekBoolean(casted->GetValue()), false); + + delete root; + delete rewrote; + + // [T.X IS NOT NULL], where X is a non-NULL column in table T + // Intended output: TRUE + + child = new expression::TupleValueExpression("t","x"); + child->SetIsNotNull(true); + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NOT_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_TRUE(rewrote != nullptr); + EXPECT_EQ(rewrote->GetChildrenSize(), 0); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::VALUE_CONSTANT); + + casted = dynamic_cast(rewrote); + EXPECT_EQ(casted->GetValueType(), type::TypeId::BOOLEAN); + EXPECT_EQ(type::ValuePeeker::PeekBoolean(casted->GetValue()), true); + + delete root; + delete rewrote; + + // [T.Y IS NULL], where Y is a possibly NULL column in table T + // Intended output: same as input + + child = new expression::TupleValueExpression("t","y"); + child->SetIsNotNull(false); // is_not_null is false by default, but explicitly setting it is for readability's sake + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_EQ(rewrote->GetChildrenSize(), 1); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::OPERATOR_IS_NULL); + + delete root; + delete rewrote; + + // [T.Y IS NOT NULL], where Y is a possibly NULL column in table T + // Intended output: same as input + + child = new expression::TupleValueExpression("t","y"); + child->SetIsNotNull(false); // is_not_null is false by default, but explicitly setting it is for readability's sake + root = new expression::OperatorExpression(ExpressionType::OPERATOR_IS_NOT_NULL, type::TypeId::BOOLEAN, child, nullptr); + + rewrote = rewriter->RewriteExpression(root); + + EXPECT_EQ(rewrote->GetChildrenSize(), 1); + EXPECT_EQ(rewrote->GetExpressionType(), ExpressionType::OPERATOR_IS_NOT_NULL); + + delete root; + delete rewrote; + + delete rewriter; +} + + +} // namespace test +} // namespace peloton diff --git a/test/optimizer/rule_rewrite_test.cpp b/test/optimizer/rule_rewrite_test.cpp new file mode 100644 index 00000000000..49cee9af835 --- /dev/null +++ b/test/optimizer/rule_rewrite_test.cpp @@ -0,0 +1,521 @@ +//===----------------------------------------------------------------------===// +// +// Peloton +// +// rule_rewrite_test.cpp +// +// Identification: test/optimizer/rule_rewrite_test.cpp +// +// Copyright (c) 2015-2018, Carnegie Mellon University Database Group +// +//===----------------------------------------------------------------------===// + +#include +#include "common/harness.h" + +#include "optimizer/operators.h" +#include "optimizer/rewriter.h" +#include "expression/constant_value_expression.h" +#include "expression/comparison_expression.h" +#include "expression/tuple_value_expression.h" +#include "type/value_factory.h" +#include "type/value_peeker.h" +#include "optimizer/rule_rewrite.h" + +namespace peloton { + +namespace test { + +using namespace optimizer; + +class RuleRewriteTests : public PelotonTest { + public: + // Creates expresson: (A = X) AND (B = Y) + expression::AbstractExpression *CreateMultiLevelExpression(expression::AbstractExpression *a, + expression::AbstractExpression *x, + expression::AbstractExpression *b, + expression::AbstractExpression *y) { + auto left_eq = new expression::ComparisonExpression( + ExpressionType::COMPARE_EQUAL, a->Copy(), x->Copy() + ); + + auto right_eq = new expression::ComparisonExpression( + ExpressionType::COMPARE_EQUAL, b->Copy(), y->Copy() + ); + + return new expression::ConjunctionExpression(ExpressionType::CONJUNCTION_AND, left_eq, right_eq); + } + + expression::ConstantValueExpression *GetConstantExpression(int val) { + auto value = type::ValueFactory::GetIntegerValue(val); + return new expression::ConstantValueExpression(value); + } +}; + +TEST_F(RuleRewriteTests, ComparatorEliminationEqual) { + // (1 == 1) => (TRUE) + auto left = GetConstantExpression(1); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left, right); + + // (1 == 2) => (FALSE) + auto left_f = GetConstantExpression(1); + auto right_f = GetConstantExpression(2); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_EQUAL, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == false); + + delete rewrote; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationNotEqual) { + // (1 != 1) => (FALSE) + auto left = GetConstantExpression(1); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_NOTEQUAL, left, right); + + // (1 != 2) => (TRUE) + auto left_f = GetConstantExpression(1); + auto right_f = GetConstantExpression(2); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_NOTEQUAL, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == true); + + delete rewrote; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationLessThan) { + // (0 < 1) => (TRUE) + auto left = GetConstantExpression(0); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHAN, left, right); + + // (1 < 1) => (FALSE) + auto left_ef = GetConstantExpression(1); + auto right_ef = GetConstantExpression(1); + auto equal_ef = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHAN, left_ef, right_ef); + + // (2 < 1) => (FALSE) + auto left_f = GetConstantExpression(2); + auto right_f = GetConstantExpression(1); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHAN, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_ef = rewriter->RewriteExpression(equal_ef); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_ef; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + auto casted_ef = dynamic_cast(rewrote_ef); + EXPECT_TRUE(casted_ef != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_ef->GetValue()) == false); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == false); + delete rewrote; + delete rewrote_ef; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationGreaterThan) { + // (0 > 1) => (FALSE) + auto left = GetConstantExpression(0); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHAN, left, right); + + // (1 < 1) => (FALSE) + auto left_ef = GetConstantExpression(1); + auto right_ef = GetConstantExpression(1); + auto equal_ef = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHAN, left_ef, right_ef); + + // (2 > 1) => (TRUE) + auto left_f = GetConstantExpression(2); + auto right_f = GetConstantExpression(1); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHAN, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_ef = rewriter->RewriteExpression(equal_ef); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_ef; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + auto casted_ef = dynamic_cast(rewrote_ef); + EXPECT_TRUE(casted_ef != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_ef->GetValue()) == false); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == true); + + delete rewrote; + delete rewrote_ef; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationLessThanOrEqualTo) { + // (0 <= 1) => (TRUE) + auto left = GetConstantExpression(0); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, left, right); + + // (1 <= 1) => (TRUE) + auto left_ef = GetConstantExpression(1); + auto right_ef = GetConstantExpression(1); + auto equal_ef = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, left_ef, right_ef); + + // (2 <= 1) => (FALSE) + auto left_f = GetConstantExpression(2); + auto right_f = GetConstantExpression(1); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_ef = rewriter->RewriteExpression(equal_ef); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_ef; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == true); + + auto casted_ef = dynamic_cast(rewrote_ef); + EXPECT_TRUE(casted_ef != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_ef->GetValue()) == true); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == false); + + delete rewrote; + delete rewrote_ef; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationGreaterThanOrEqualTo) { + // (0 >= 1) => (FALSE) + auto left = GetConstantExpression(0); + auto right = GetConstantExpression(1); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, left, right); + + // (1 >= 1) => (TRUE) + auto left_ef = GetConstantExpression(1); + auto right_ef = GetConstantExpression(1); + auto equal_ef = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, left_ef, right_ef); + + // (2 >= 1) => (TRUE) + auto left_f = GetConstantExpression(2); + auto right_f = GetConstantExpression(1); + auto equal_f = new expression::ComparisonExpression(ExpressionType::COMPARE_GREATERTHANOREQUALTO, left_f, right_f); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + auto rewrote_ef = rewriter->RewriteExpression(equal_ef); + auto rewrote_f = rewriter->RewriteExpression(equal_f); + delete rewriter; + delete equal; + delete equal_ef; + delete equal_f; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted->GetValue()) == false); + + auto casted_ef = dynamic_cast(rewrote_ef); + EXPECT_TRUE(casted_ef != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_ef->GetValue()) == true); + + auto casted_f = dynamic_cast(rewrote_f); + EXPECT_TRUE(casted_f != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(casted_f->GetValue()) == true); + + delete rewrote; + delete rewrote_ef; + delete rewrote_f; +} + +TEST_F(RuleRewriteTests, ComparatorEliminationLessThanOrEqualToNull) { + auto valNULL = type::ValueFactory::GetNullValueByType(type::TypeId::INTEGER); + + // 0 <= NULL => NULL + auto left = GetConstantExpression(2); + auto right = new expression::ConstantValueExpression(valNULL); + auto equal = new expression::ComparisonExpression(ExpressionType::COMPARE_LESSTHANOREQUALTO, left, right); + + Rewriter *rewriter = new Rewriter(); + auto rewrote = rewriter->RewriteExpression(equal); + delete rewriter; + delete equal; + + auto casted = dynamic_cast(rewrote); + EXPECT_TRUE(casted != nullptr); + + auto value = casted->GetValue(); + EXPECT_TRUE(value.GetTypeId() == type::TypeId::BOOLEAN); + EXPECT_TRUE(value.IsNull()); + + delete rewrote; +} + +TEST_F(RuleRewriteTests, TVEqualTwoCVFalseTransform) { + auto cv1 = GetConstantExpression(1); + auto cv2 = GetConstantExpression(2); + auto tv_base = new expression::TupleValueExpression("B", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base: (A.B = 1) AND (A.B = 2) + auto base = CreateMultiLevelExpression(tv_base, cv1, tv_base, cv2); + + // Inverse: (1 = A.B) AND (2 = A.B) + auto inverse = CreateMultiLevelExpression(cv1, tv_base, cv2, tv_base); + + // Inner Flip Left: (1 = A.B) AND (A.B = 2) + auto if_left = CreateMultiLevelExpression(cv1, tv_base, tv_base, cv2); + + // Inner Flip Right: (A.B = 1) AND (2 = A.B) + auto if_right = CreateMultiLevelExpression(tv_base, cv1, cv2, tv_base); + + std::vector rewrites; + rewrites.push_back(rewriter->RewriteExpression(base)); + rewrites.push_back(rewriter->RewriteExpression(inverse)); + rewrites.push_back(rewriter->RewriteExpression(if_left)); + rewrites.push_back(rewriter->RewriteExpression(if_right)); + delete rewriter; + delete cv1; + delete cv2; + delete tv_base; + delete base; + delete inverse; + delete if_left; + delete if_right; + + for (auto expr : rewrites) { + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT); + EXPECT_TRUE(expr->GetChildrenSize() == 0); + + auto expr_val = dynamic_cast(expr); + EXPECT_TRUE(expr_val != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekBoolean(expr_val->GetValue()) == false); + } + + while (!rewrites.empty()) { + auto expr = rewrites.back(); + rewrites.pop_back(); + delete expr; + } +} + +TEST_F(RuleRewriteTests, TVEqualTwoCVTrueTransform) { + auto cv1 = GetConstantExpression(1); + auto tv_base = new expression::TupleValueExpression("B", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base: (A.B = 1) AND (A.B = 1) + auto base = CreateMultiLevelExpression(tv_base, cv1, tv_base, cv1); + + // Inverse: (1 = A.B) AND (1 = A.B) + auto inverse = CreateMultiLevelExpression(cv1, tv_base, cv1, tv_base); + + // Inner Flip Left: (1 = A.B) AND (A.B = 1) + auto if_left = CreateMultiLevelExpression(cv1, tv_base, tv_base, cv1); + + // Inner Flip Right: (A.B = 1) AND (1 = A.B) + auto if_right = CreateMultiLevelExpression(tv_base, cv1, cv1, tv_base); + + std::vector rewrites; + rewrites.push_back(rewriter->RewriteExpression(base)); + rewrites.push_back(rewriter->RewriteExpression(inverse)); + rewrites.push_back(rewriter->RewriteExpression(if_left)); + rewrites.push_back(rewriter->RewriteExpression(if_right)); + delete rewriter; + delete cv1; + delete base; + delete inverse; + delete if_left; + delete if_right; + + for (auto expr : rewrites) { + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(expr->GetChildrenSize() == 2); + + auto left_tv = expr->GetModifiableChild(0); + auto tv = dynamic_cast(left_tv); + EXPECT_TRUE(tv != nullptr); + EXPECT_TRUE(tv->ExactlyEquals(*tv_base)); + + auto right_cv = expr->GetModifiableChild(1); + auto cv = dynamic_cast(right_cv); + EXPECT_TRUE(cv != nullptr); + EXPECT_TRUE(type::ValuePeeker::PeekInteger(cv->GetValue()) == 1); + } + + while (!rewrites.empty()) { + auto expr = rewrites.back(); + rewrites.pop_back(); + delete expr; + } + + delete tv_base; +} + +TEST_F(RuleRewriteTests, TransitiveClosureUnableTest) { + auto cv1 = GetConstantExpression(1); + auto tv_base1 = new expression::TupleValueExpression("B", "A"); + auto tv_base2 = new expression::TupleValueExpression("C", "A"); + auto tv_base3 = new expression::TupleValueExpression("D", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base (A = 1) AND (B = C) + auto base = CreateMultiLevelExpression(tv_base1, cv1, tv_base2, tv_base3); + + auto expr = rewriter->RewriteExpression(base); + delete rewriter; + delete base; + + // Returned expression should not be changed + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND); + EXPECT_TRUE(expr->GetChildrenSize() == 2); + + auto left_eq = expr->GetModifiableChild(0); + auto right_eq = expr->GetModifiableChild(1); + EXPECT_TRUE(left_eq->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(right_eq->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(left_eq->GetChildrenSize() == 2); + EXPECT_TRUE(right_eq->GetChildrenSize() == 2); + + auto ll_tv = dynamic_cast(left_eq->GetModifiableChild(0)); + auto lr_cv = dynamic_cast(left_eq->GetModifiableChild(1)); + auto rl_tv = dynamic_cast(right_eq->GetModifiableChild(0)); + auto rr_tv = dynamic_cast(right_eq->GetModifiableChild(1)); + EXPECT_TRUE(ll_tv != nullptr && lr_cv != nullptr && rl_tv != nullptr && rr_tv != nullptr); + EXPECT_TRUE(lr_cv->ExactlyEquals(*cv1)); + EXPECT_TRUE(ll_tv->ExactlyEquals(*tv_base1)); + EXPECT_TRUE(rl_tv->ExactlyEquals(*tv_base2)); + EXPECT_TRUE(rr_tv->ExactlyEquals(*tv_base3)); + + delete cv1; + delete tv_base1; + delete tv_base2; + delete tv_base3; + delete expr; +} + +TEST_F(RuleRewriteTests, TransitiveClosureRewrite) { + auto cv1 = GetConstantExpression(1); + auto tv_base1 = new expression::TupleValueExpression("B", "A"); + auto tv_base2 = new expression::TupleValueExpression("C", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base (A = 1) AND (A = B) + auto base = CreateMultiLevelExpression(tv_base1, cv1, tv_base1, tv_base2); + + auto expr = rewriter->RewriteExpression(base); + delete rewriter; + delete base; + + // Returned expression should not be changed + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND); + EXPECT_TRUE(expr->GetChildrenSize() == 2); + + auto left_eq = expr->GetModifiableChild(0); + auto right_eq = expr->GetModifiableChild(1); + EXPECT_TRUE(left_eq->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(right_eq->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(left_eq->GetChildrenSize() == 2); + EXPECT_TRUE(right_eq->GetChildrenSize() == 2); + + auto ll_tv = dynamic_cast(left_eq->GetModifiableChild(0)); + auto lr_cv = dynamic_cast(left_eq->GetModifiableChild(1)); + auto rl_cv = dynamic_cast(right_eq->GetModifiableChild(0)); + auto rr_tv = dynamic_cast(right_eq->GetModifiableChild(1)); + EXPECT_TRUE(ll_tv != nullptr && lr_cv != nullptr && rl_cv != nullptr && rr_tv != nullptr); + EXPECT_TRUE(lr_cv->ExactlyEquals(*cv1)); + EXPECT_TRUE(ll_tv->ExactlyEquals(*tv_base1)); + EXPECT_TRUE(rl_cv->ExactlyEquals(*cv1)); + EXPECT_TRUE(rr_tv->ExactlyEquals(*tv_base2)); + + delete cv1; + delete tv_base1; + delete tv_base2; + delete expr; +} + +TEST_F(RuleRewriteTests, TransitiveClosureHalfTrue) { + auto cv1 = GetConstantExpression(1); + auto tv_base1 = new expression::TupleValueExpression("B", "A"); + + Rewriter *rewriter = new Rewriter(); + + // Base (A = 1) AND (A = B) + auto base = CreateMultiLevelExpression(tv_base1, cv1, tv_base1, tv_base1); + + auto expr = rewriter->RewriteExpression(base); + delete rewriter; + delete base; + + // Returned expression should not be changed + EXPECT_TRUE(expr->GetExpressionType() == ExpressionType::COMPARE_EQUAL); + EXPECT_TRUE(expr->GetChildrenSize() == 2); + + auto ll_tv = dynamic_cast(expr->GetModifiableChild(0)); + auto lr_cv = dynamic_cast(expr->GetModifiableChild(1)); + EXPECT_TRUE(ll_tv != nullptr && lr_cv != nullptr); + EXPECT_TRUE(lr_cv->ExactlyEquals(*cv1)); + EXPECT_TRUE(ll_tv->ExactlyEquals(*tv_base1)); + + delete cv1; + delete tv_base1; + delete expr; +} + +} // namespace test +} // namespace peloton