Skip to content

Commit

Permalink
[SPARK-50597][SQL] Refactor batch construction in Optimizer.scala and…
Browse files Browse the repository at this point in the history
… SparkOptimizer.scala

### What changes were proposed in this pull request?

See description. Previously, it was a pain to reorder batches and guard behavior of certain batches / sequences of batches by a flag. This was primarily due to ample usage of `::`, `:::`, and `:+` to juggle rules and batches around which imposed syntactic limitations.

After this change, we keep a single sequence `allBatches`, that can contain either `Batch` or `Seq[Batch]` elements to allow further groupings, which is later flattened into a single `Seq[Batch]`.

We avoid any usage of `::`, `:::`, and `:+`.

To add/replace a flag-guarded batch of sequence of batches, write a function that returns either `Batch` of `Seq[Batch]` with desired behavior, and add/replace in the relevant place in the `allBatches` list.

### Why are the changes needed?

This simplifies further restructuring and reordering of batches.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

No tests.

### Was this patch authored or co-authored using generative AI tooling?

Closes #49208 from anton5798/batch-refactor.

Authored-by: Anton Lykov <anton.lykov@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
anton5798 authored and cloud-fan committed Dec 17, 2024
1 parent 79026ad commit 63c7ca4
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,21 @@ abstract class Optimizer(catalogManager: CatalogManager)
conf.optimizerMaxIterations,
maxIterationsSetting = SQLConf.OPTIMIZER_MAX_ITERATIONS.key)

/**
* A helper method that takes as input a Seq of Batch or Seq[Batch], and flattens it out.
*/
def flattenBatches(nestedBatchSequence: Seq[Any]): Seq[Batch] = {
assert(nestedBatchSequence.forall {
case _: Batch => true
case s: Seq[_] => s.forall(_.isInstanceOf[Batch])
case _ => false
})
nestedBatchSequence.flatMap {
case batches: Seq[Batch @unchecked] => batches
case batch: Batch => Seq(batch)
}
}

/**
* Defines the default rule batches in the Optimizer.
*
Expand Down Expand Up @@ -143,39 +158,38 @@ abstract class Optimizer(catalogManager: CatalogManager)
PushdownPredicatesAndPruneColumnsForCTEDef) ++
extendedOperatorOptimizationRules

val operatorOptimizationBatch: Seq[Batch] = {
val operatorOptimizationBatch: Seq[Batch] = Seq(
Batch("Operator Optimization before Inferring Filters", fixedPoint,
operatorOptimizationRuleSet: _*) ::
operatorOptimizationRuleSet: _*),
Batch("Infer Filters", Once,
InferFiltersFromGenerate,
InferFiltersFromConstraints) ::
InferFiltersFromConstraints),
Batch("Operator Optimization after Inferring Filters", fixedPoint,
operatorOptimizationRuleSet: _*) ::
operatorOptimizationRuleSet: _*),
Batch("Push extra predicate through join", fixedPoint,
PushExtraPredicateThroughJoin,
PushDownPredicates) :: Nil
}
PushDownPredicates))

val batches = (
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
val batches: Seq[Batch] = flattenBatches(Seq(
Batch("Finish Analysis", FixedPoint(1), FinishAnalysis),
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
// may produce `With` expressions that need to be rewritten.
Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) ::
Batch("Rewrite With expression", fixedPoint, RewriteWithExpression),
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
Batch("Eliminate Distinct", Once, EliminateDistinct) ::
Batch("Eliminate Distinct", Once, EliminateDistinct),
// - Do the first call of CombineUnions before starting the major Optimizer rules,
// since it can reduce the number of iteration and the other rules could add/move
// extra operators between two adjacent Union operators.
// - Call CombineUnions again in Batch("Operator Optimizations"),
// since the other rules might make two separate Unions operators adjacent.
Batch("Inline CTE", Once,
InlineCTE()) ::
InlineCTE()),
Batch("Union", fixedPoint,
RemoveNoopOperators,
CombineUnions,
RemoveNoopUnion) ::
RemoveNoopUnion),
// Run this once earlier. This might simplify the plan and reduce cost of optimizer.
// For example, a query such as Filter(LocalRelation) would go through all the heavy
// optimizer rules that are triggered when there is a filter
Expand All @@ -186,76 +200,76 @@ abstract class Optimizer(catalogManager: CatalogManager)
PropagateEmptyRelation,
// PropagateEmptyRelation can change the nullability of an attribute from nullable to
// non-nullable when an empty relation child of a Union is removed
UpdateAttributeNullability) ::
UpdateAttributeNullability),
Batch("Pullup Correlated Expressions", Once,
OptimizeOneRowRelationSubquery,
PullOutNestedDataOuterRefExpressions,
PullupCorrelatedPredicates) ::
PullupCorrelatedPredicates),
// Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense
// to enforce idempotence on it and we change this batch from Once to FixedPoint(1).
Batch("Subquery", FixedPoint(1),
OptimizeSubqueries,
OptimizeOneRowRelationSubquery) ::
OptimizeOneRowRelationSubquery),
Batch("Replace Operators", fixedPoint,
RewriteExceptAll,
RewriteIntersectAll,
ReplaceIntersectWithSemiJoin,
ReplaceExceptWithFilter,
ReplaceExceptWithAntiJoin,
ReplaceDistinctWithAggregate,
ReplaceDeduplicateWithAggregate) ::
ReplaceDeduplicateWithAggregate),
Batch("Aggregate", fixedPoint,
RemoveLiteralFromGroupExpressions,
RemoveRepetitionFromGroupExpressions) :: Nil ++
operatorOptimizationBatch) :+
Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+
RemoveRepetitionFromGroupExpressions),
operatorOptimizationBatch,
Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo),
// This batch rewrites plans after the operator optimization and
// before any batches that depend on stats.
Batch("Pre CBO Rules", Once, preCBORules: _*) :+
Batch("Pre CBO Rules", Once, preCBORules: _*),
// This batch pushes filters and projections into scan nodes. Before this batch, the logical
// plan may contain nodes that do not report stats. Anything that uses stats must run after
// this batch.
Batch("Early Filter and Projection Push-Down", Once, earlyScanPushDownRules: _*) :+
Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats) :+
Batch("Early Filter and Projection Push-Down", Once, earlyScanPushDownRules: _*),
Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats),
// Since join costs in AQP can change between multiple runs, there is no reason that we have an
// idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once.
Batch("Join Reorder", FixedPoint(1),
CostBasedJoinReorder) :+
CostBasedJoinReorder),
Batch("Eliminate Sorts", Once,
EliminateSorts,
RemoveRedundantSorts) :+
RemoveRedundantSorts),
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) :+
DecimalAggregates),
// This batch must run after "Decimal Optimizations", as that one may change the
// aggregate distinct column
Batch("Distinct Aggregate Rewrite", Once,
RewriteDistinctAggregates) :+
RewriteDistinctAggregates),
Batch("Object Expressions Optimization", fixedPoint,
EliminateMapObjects,
CombineTypedFilters,
ObjectSerializerPruning,
ReassignLambdaVariableID) :+
ReassignLambdaVariableID),
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation,
PropagateEmptyRelation,
// PropagateEmptyRelation can change the nullability of an attribute from nullable to
// non-nullable when an empty relation child of a Union is removed
UpdateAttributeNullability) :+
Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan) :+
UpdateAttributeNullability),
Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan),
// The following batch should be executed after batch "Join Reorder" and "LocalRelation".
Batch("Check Cartesian Products", Once,
CheckCartesianProducts) :+
CheckCartesianProducts),
Batch("RewriteSubquery", Once,
RewritePredicateSubquery,
PushPredicateThroughJoin,
LimitPushDown,
ColumnPruning,
CollapseProject,
RemoveRedundantAliases,
RemoveNoopOperators) :+
RemoveNoopOperators),
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers),
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)))

// remove any batches with no rules. this may happen when subclasses do not add optional rules.
batches.filter(_.rules.nonEmpty)
Expand All @@ -270,22 +284,23 @@ abstract class Optimizer(catalogManager: CatalogManager)
* (defaultBatches - (excludedRules - nonExcludableRules)).
*/
def nonExcludableRules: Seq[String] =
FinishAnalysis.ruleName ::
RewriteDistinctAggregates.ruleName ::
ReplaceDeduplicateWithAggregate.ruleName ::
ReplaceIntersectWithSemiJoin.ruleName ::
ReplaceExceptWithFilter.ruleName ::
ReplaceExceptWithAntiJoin.ruleName ::
RewriteExceptAll.ruleName ::
RewriteIntersectAll.ruleName ::
ReplaceDistinctWithAggregate.ruleName ::
PullupCorrelatedPredicates.ruleName ::
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
ReplaceUpdateFieldsExpression.ruleName ::
RewriteLateralSubquery.ruleName ::
OptimizeSubqueries.ruleName :: Nil
Seq(
FinishAnalysis.ruleName,
RewriteDistinctAggregates.ruleName,
ReplaceDeduplicateWithAggregate.ruleName,
ReplaceIntersectWithSemiJoin.ruleName,
ReplaceExceptWithFilter.ruleName,
ReplaceExceptWithAntiJoin.ruleName,
RewriteExceptAll.ruleName,
RewriteIntersectAll.ruleName,
ReplaceDistinctWithAggregate.ruleName,
PullupCorrelatedPredicates.ruleName,
RewriteCorrelatedScalarSubquery.ruleName,
RewritePredicateSubquery.ruleName,
NormalizeFloatingNumbers.ruleName,
ReplaceUpdateFieldsExpression.ruleName,
RewriteLateralSubquery.ruleName,
OptimizeSubqueries.ruleName)

/**
* Apply finish-analysis rules for the entire plan including all subqueries.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,38 +36,41 @@ class SparkOptimizer(

override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
// TODO: move SchemaPruning into catalyst
Seq(SchemaPruning) :+
GroupBasedRowLevelOperationScanPlanning :+
V1Writes :+
V2ScanRelationPushDown :+
V2ScanPartitioningAndOrdering :+
V2Writes :+
PruneFileSourcePartitions
Seq(
SchemaPruning,
GroupBasedRowLevelOperationScanPlanning,
V1Writes,
V2ScanRelationPushDown,
V2ScanPartitioningAndOrdering,
V2Writes,
PruneFileSourcePartitions)

override def preCBORules: Seq[Rule[LogicalPlan]] =
OptimizeMetadataOnlyDeleteFromTable :: Nil
Seq(OptimizeMetadataOnlyDeleteFromTable)

override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
override def defaultBatches: Seq[Batch] = flattenBatches(Seq(
preOptimizationBatches,
super.defaultBatches,
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)),
Batch("PartitionPruning", Once,
PartitionPruning,
// We can't run `OptimizeSubqueries` in this batch, as it will optimize the subqueries
// twice which may break some optimizer rules that can only be applied once. The rule below
// only invokes `OptimizeSubqueries` to optimize newly added subqueries.
new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+
new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)),
Batch("InjectRuntimeFilter", FixedPoint(1),
InjectRuntimeFilter) :+
InjectRuntimeFilter),
Batch("MergeScalarSubqueries", Once,
MergeScalarSubqueries,
RewriteDistinctAggregates) :+
RewriteDistinctAggregates),
Batch("Pushdown Filters from PartitionPruning", fixedPoint,
PushDownPredicates) :+
PushDownPredicates),
Batch("Cleanup filters that cannot be pushed down", Once,
CleanupDynamicPruningFilters,
// cleanup the unnecessary TrueLiteral predicates
BooleanSimplification,
PruneFilters)) ++
postHocOptimizationBatches :+
PruneFilters),
postHocOptimizationBatches,
Batch("Extract Python UDFs", Once,
ExtractPythonUDFFromJoinCondition,
// `ExtractPythonUDFFromJoinCondition` can convert a join to a cartesian product.
Expand All @@ -84,25 +87,27 @@ class SparkOptimizer(
LimitPushDown,
PushPredicateThroughNonJoin,
PushProjectionThroughLimit,
RemoveNoopOperators) :+
RemoveNoopOperators),
Batch("Infer window group limit", Once,
InferWindowGroupLimit,
LimitPushDown,
LimitPushDownThroughWindow,
EliminateLimits,
ConstantFolding) :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) :+
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)
ConstantFolding),
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*),
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))

override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
ExtractPythonUDFFromJoinCondition.ruleName :+
ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+
ExtractPythonUDFs.ruleName :+
GroupBasedRowLevelOperationScanPlanning.ruleName :+
V2ScanRelationPushDown.ruleName :+
V2ScanPartitioningAndOrdering.ruleName :+
V2Writes.ruleName :+
ReplaceCTERefWithRepartition.ruleName
override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++
Seq(
ExtractPythonUDFFromJoinCondition.ruleName,
ExtractPythonUDFFromAggregate.ruleName,
ExtractGroupingPythonUDFFromAggregate.ruleName,
ExtractPythonUDFs.ruleName,
GroupBasedRowLevelOperationScanPlanning.ruleName,
V2ScanRelationPushDown.ruleName,
V2ScanPartitioningAndOrdering.ruleName,
V2Writes.ruleName,
ReplaceCTERefWithRepartition.ruleName)

/**
* Optimization batches that are executed before the regular optimization batches (also before
Expand Down

0 comments on commit 63c7ca4

Please sign in to comment.