From 63c7ca4df2970d12574ad3b542ec17eb5276ef86 Mon Sep 17 00:00:00 2001 From: Anton Lykov Date: Tue, 17 Dec 2024 22:39:11 +0800 Subject: [PATCH] [SPARK-50597][SQL] Refactor batch construction in Optimizer.scala and SparkOptimizer.scala ### What changes were proposed in this pull request? See description. Previously, it was a pain to reorder batches and guard behavior of certain batches / sequences of batches by a flag. This was primarily due to ample usage of `::`, `:::`, and `:+` to juggle rules and batches around which imposed syntactic limitations. After this change, we keep a single sequence `allBatches`, that can contain either `Batch` or `Seq[Batch]` elements to allow further groupings, which is later flattened into a single `Seq[Batch]`. We avoid any usage of `::`, `:::`, and `:+`. To add/replace a flag-guarded batch of sequence of batches, write a function that returns either `Batch` of `Seq[Batch]` with desired behavior, and add/replace in the relevant place in the `allBatches` list. ### Why are the changes needed? This simplifies further restructuring and reordering of batches. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? No tests. ### Was this patch authored or co-authored using generative AI tooling? Closes #49208 from anton5798/batch-refactor. Authored-by: Anton Lykov Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 113 ++++++++++-------- .../spark/sql/execution/SparkOptimizer.scala | 63 +++++----- 2 files changed, 98 insertions(+), 78 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 31c1f89177632..b141d2be04c32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -73,6 +73,21 @@ abstract class Optimizer(catalogManager: CatalogManager) conf.optimizerMaxIterations, maxIterationsSetting = SQLConf.OPTIMIZER_MAX_ITERATIONS.key) + /** + * A helper method that takes as input a Seq of Batch or Seq[Batch], and flattens it out. + */ + def flattenBatches(nestedBatchSequence: Seq[Any]): Seq[Batch] = { + assert(nestedBatchSequence.forall { + case _: Batch => true + case s: Seq[_] => s.forall(_.isInstanceOf[Batch]) + case _ => false + }) + nestedBatchSequence.flatMap { + case batches: Seq[Batch @unchecked] => batches + case batch: Batch => Seq(batch) + } + } + /** * Defines the default rule batches in the Optimizer. * @@ -143,39 +158,38 @@ abstract class Optimizer(catalogManager: CatalogManager) PushdownPredicatesAndPruneColumnsForCTEDef) ++ extendedOperatorOptimizationRules - val operatorOptimizationBatch: Seq[Batch] = { + val operatorOptimizationBatch: Seq[Batch] = Seq( Batch("Operator Optimization before Inferring Filters", fixedPoint, - operatorOptimizationRuleSet: _*) :: + operatorOptimizationRuleSet: _*), Batch("Infer Filters", Once, InferFiltersFromGenerate, - InferFiltersFromConstraints) :: + InferFiltersFromConstraints), Batch("Operator Optimization after Inferring Filters", fixedPoint, - operatorOptimizationRuleSet: _*) :: + operatorOptimizationRuleSet: _*), Batch("Push extra predicate through join", fixedPoint, PushExtraPredicateThroughJoin, - PushDownPredicates) :: Nil - } + PushDownPredicates)) - val batches = ( - Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) :: + val batches: Seq[Batch] = flattenBatches(Seq( + Batch("Finish Analysis", FixedPoint(1), FinishAnalysis), // We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression // may produce `With` expressions that need to be rewritten. - Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) :: + Batch("Rewrite With expression", fixedPoint, RewriteWithExpression), ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// - Batch("Eliminate Distinct", Once, EliminateDistinct) :: + Batch("Eliminate Distinct", Once, EliminateDistinct), // - Do the first call of CombineUnions before starting the major Optimizer rules, // since it can reduce the number of iteration and the other rules could add/move // extra operators between two adjacent Union operators. // - Call CombineUnions again in Batch("Operator Optimizations"), // since the other rules might make two separate Unions operators adjacent. Batch("Inline CTE", Once, - InlineCTE()) :: + InlineCTE()), Batch("Union", fixedPoint, RemoveNoopOperators, CombineUnions, - RemoveNoopUnion) :: + RemoveNoopUnion), // Run this once earlier. This might simplify the plan and reduce cost of optimizer. // For example, a query such as Filter(LocalRelation) would go through all the heavy // optimizer rules that are triggered when there is a filter @@ -186,16 +200,16 @@ abstract class Optimizer(catalogManager: CatalogManager) PropagateEmptyRelation, // PropagateEmptyRelation can change the nullability of an attribute from nullable to // non-nullable when an empty relation child of a Union is removed - UpdateAttributeNullability) :: + UpdateAttributeNullability), Batch("Pullup Correlated Expressions", Once, OptimizeOneRowRelationSubquery, PullOutNestedDataOuterRefExpressions, - PullupCorrelatedPredicates) :: + PullupCorrelatedPredicates), // Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense // to enforce idempotence on it and we change this batch from Once to FixedPoint(1). Batch("Subquery", FixedPoint(1), OptimizeSubqueries, - OptimizeOneRowRelationSubquery) :: + OptimizeOneRowRelationSubquery), Batch("Replace Operators", fixedPoint, RewriteExceptAll, RewriteIntersectAll, @@ -203,48 +217,48 @@ abstract class Optimizer(catalogManager: CatalogManager) ReplaceExceptWithFilter, ReplaceExceptWithAntiJoin, ReplaceDistinctWithAggregate, - ReplaceDeduplicateWithAggregate) :: + ReplaceDeduplicateWithAggregate), Batch("Aggregate", fixedPoint, RemoveLiteralFromGroupExpressions, - RemoveRepetitionFromGroupExpressions) :: Nil ++ - operatorOptimizationBatch) :+ - Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+ + RemoveRepetitionFromGroupExpressions), + operatorOptimizationBatch, + Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo), // This batch rewrites plans after the operator optimization and // before any batches that depend on stats. - Batch("Pre CBO Rules", Once, preCBORules: _*) :+ + Batch("Pre CBO Rules", Once, preCBORules: _*), // This batch pushes filters and projections into scan nodes. Before this batch, the logical // plan may contain nodes that do not report stats. Anything that uses stats must run after // this batch. - Batch("Early Filter and Projection Push-Down", Once, earlyScanPushDownRules: _*) :+ - Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats) :+ + Batch("Early Filter and Projection Push-Down", Once, earlyScanPushDownRules: _*), + Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats), // Since join costs in AQP can change between multiple runs, there is no reason that we have an // idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once. Batch("Join Reorder", FixedPoint(1), - CostBasedJoinReorder) :+ + CostBasedJoinReorder), Batch("Eliminate Sorts", Once, EliminateSorts, - RemoveRedundantSorts) :+ + RemoveRedundantSorts), Batch("Decimal Optimizations", fixedPoint, - DecimalAggregates) :+ + DecimalAggregates), // This batch must run after "Decimal Optimizations", as that one may change the // aggregate distinct column Batch("Distinct Aggregate Rewrite", Once, - RewriteDistinctAggregates) :+ + RewriteDistinctAggregates), Batch("Object Expressions Optimization", fixedPoint, EliminateMapObjects, CombineTypedFilters, ObjectSerializerPruning, - ReassignLambdaVariableID) :+ + ReassignLambdaVariableID), Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation, // PropagateEmptyRelation can change the nullability of an attribute from nullable to // non-nullable when an empty relation child of a Union is removed - UpdateAttributeNullability) :+ - Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan) :+ + UpdateAttributeNullability), + Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan), // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, - CheckCartesianProducts) :+ + CheckCartesianProducts), Batch("RewriteSubquery", Once, RewritePredicateSubquery, PushPredicateThroughJoin, @@ -252,10 +266,10 @@ abstract class Optimizer(catalogManager: CatalogManager) ColumnPruning, CollapseProject, RemoveRedundantAliases, - RemoveNoopOperators) :+ + RemoveNoopOperators), // This batch must be executed after the `RewriteSubquery` batch, which creates joins. - Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ - Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression) + Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers), + Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression))) // remove any batches with no rules. this may happen when subclasses do not add optional rules. batches.filter(_.rules.nonEmpty) @@ -270,22 +284,23 @@ abstract class Optimizer(catalogManager: CatalogManager) * (defaultBatches - (excludedRules - nonExcludableRules)). */ def nonExcludableRules: Seq[String] = - FinishAnalysis.ruleName :: - RewriteDistinctAggregates.ruleName :: - ReplaceDeduplicateWithAggregate.ruleName :: - ReplaceIntersectWithSemiJoin.ruleName :: - ReplaceExceptWithFilter.ruleName :: - ReplaceExceptWithAntiJoin.ruleName :: - RewriteExceptAll.ruleName :: - RewriteIntersectAll.ruleName :: - ReplaceDistinctWithAggregate.ruleName :: - PullupCorrelatedPredicates.ruleName :: - RewriteCorrelatedScalarSubquery.ruleName :: - RewritePredicateSubquery.ruleName :: - NormalizeFloatingNumbers.ruleName :: - ReplaceUpdateFieldsExpression.ruleName :: - RewriteLateralSubquery.ruleName :: - OptimizeSubqueries.ruleName :: Nil + Seq( + FinishAnalysis.ruleName, + RewriteDistinctAggregates.ruleName, + ReplaceDeduplicateWithAggregate.ruleName, + ReplaceIntersectWithSemiJoin.ruleName, + ReplaceExceptWithFilter.ruleName, + ReplaceExceptWithAntiJoin.ruleName, + RewriteExceptAll.ruleName, + RewriteIntersectAll.ruleName, + ReplaceDistinctWithAggregate.ruleName, + PullupCorrelatedPredicates.ruleName, + RewriteCorrelatedScalarSubquery.ruleName, + RewritePredicateSubquery.ruleName, + NormalizeFloatingNumbers.ruleName, + ReplaceUpdateFieldsExpression.ruleName, + RewriteLateralSubquery.ruleName, + OptimizeSubqueries.ruleName) /** * Apply finish-analysis rules for the entire plan including all subqueries. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 6173703ef3cd9..6ceb363b41aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -36,38 +36,41 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - Seq(SchemaPruning) :+ - GroupBasedRowLevelOperationScanPlanning :+ - V1Writes :+ - V2ScanRelationPushDown :+ - V2ScanPartitioningAndOrdering :+ - V2Writes :+ - PruneFileSourcePartitions + Seq( + SchemaPruning, + GroupBasedRowLevelOperationScanPlanning, + V1Writes, + V2ScanRelationPushDown, + V2ScanPartitioningAndOrdering, + V2Writes, + PruneFileSourcePartitions) override def preCBORules: Seq[Rule[LogicalPlan]] = - OptimizeMetadataOnlyDeleteFromTable :: Nil + Seq(OptimizeMetadataOnlyDeleteFromTable) - override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ - Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ + override def defaultBatches: Seq[Batch] = flattenBatches(Seq( + preOptimizationBatches, + super.defaultBatches, + Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)), Batch("PartitionPruning", Once, PartitionPruning, // We can't run `OptimizeSubqueries` in this batch, as it will optimize the subqueries // twice which may break some optimizer rules that can only be applied once. The rule below // only invokes `OptimizeSubqueries` to optimize newly added subqueries. - new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+ + new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)), Batch("InjectRuntimeFilter", FixedPoint(1), - InjectRuntimeFilter) :+ + InjectRuntimeFilter), Batch("MergeScalarSubqueries", Once, MergeScalarSubqueries, - RewriteDistinctAggregates) :+ + RewriteDistinctAggregates), Batch("Pushdown Filters from PartitionPruning", fixedPoint, - PushDownPredicates) :+ + PushDownPredicates), Batch("Cleanup filters that cannot be pushed down", Once, CleanupDynamicPruningFilters, // cleanup the unnecessary TrueLiteral predicates BooleanSimplification, - PruneFilters)) ++ - postHocOptimizationBatches :+ + PruneFilters), + postHocOptimizationBatches, Batch("Extract Python UDFs", Once, ExtractPythonUDFFromJoinCondition, // `ExtractPythonUDFFromJoinCondition` can convert a join to a cartesian product. @@ -84,25 +87,27 @@ class SparkOptimizer( LimitPushDown, PushPredicateThroughNonJoin, PushProjectionThroughLimit, - RemoveNoopOperators) :+ + RemoveNoopOperators), Batch("Infer window group limit", Once, InferWindowGroupLimit, LimitPushDown, LimitPushDownThroughWindow, EliminateLimits, - ConstantFolding) :+ - Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) :+ - Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition) + ConstantFolding), + Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*), + Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition))) - override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+ - ExtractPythonUDFFromJoinCondition.ruleName :+ - ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+ - ExtractPythonUDFs.ruleName :+ - GroupBasedRowLevelOperationScanPlanning.ruleName :+ - V2ScanRelationPushDown.ruleName :+ - V2ScanPartitioningAndOrdering.ruleName :+ - V2Writes.ruleName :+ - ReplaceCTERefWithRepartition.ruleName + override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++ + Seq( + ExtractPythonUDFFromJoinCondition.ruleName, + ExtractPythonUDFFromAggregate.ruleName, + ExtractGroupingPythonUDFFromAggregate.ruleName, + ExtractPythonUDFs.ruleName, + GroupBasedRowLevelOperationScanPlanning.ruleName, + V2ScanRelationPushDown.ruleName, + V2ScanPartitioningAndOrdering.ruleName, + V2Writes.ruleName, + ReplaceCTERefWithRepartition.ruleName) /** * Optimization batches that are executed before the regular optimization batches (also before