From 5f1b5c9d28017ee1e335ddbefbdd3547445cb2ae Mon Sep 17 00:00:00 2001 From: Tim Meehan Date: Tue, 30 Jul 2024 10:37:54 -0400 Subject: [PATCH] Add DelegatingRowExpressionOptimizer --- .../presto/SystemSessionProperties.java | 11 + .../presto/cost/ScalarStatsCalculator.java | 8 +- .../facebook/presto/server/PluginManager.java | 12 +- .../facebook/presto/server/PrestoServer.java | 3 + .../presto/server/ServerMainModule.java | 4 + .../server/testing/TestingPrestoServer.java | 9 + .../presto/sql/analyzer/FeaturesConfig.java | 29 ++ .../ExpressionOptimizerManager.java | 103 +++++++ .../presto/sql/planner/PlanOptimizers.java | 16 +- .../rule/SimplifyRowExpressions.java | 37 ++- .../CteProjectionAndPredicatePushDown.java | 12 +- .../DelegatingRowExpressionOptimizer.java | 90 ++++++ .../relational/RowExpressionOptimizer.java | 17 +- .../presto/testing/LocalQueryRunner.java | 22 +- .../facebook/presto/testing/QueryRunner.java | 3 + .../sql/analyzer/TestFeaturesConfig.java | 10 +- .../sql/planner/TestLogicalPlanner.java | 34 ++- .../planner/assertions/OptimizerAssert.java | 11 +- .../iterative/rule/TestRemoveMapCastRule.java | 31 +- ...teConstantArrayContainsToInExpression.java | 66 ++-- .../rule/TestSimplifyRowExpressions.java | 22 +- .../iterative/rule/test/BaseRuleTest.java | 6 + .../iterative/rule/test/RuleTester.java | 9 + ...TestCteProjectionAndPredicatePushdown.java | 3 +- .../TestDelegatingRowExpressionOptimizer.java | 223 ++++++++++++++ .../nativeworker/ContainerQueryRunner.java | 7 + .../presto/spark/PrestoSparkModule.java | 6 + .../presto/spark/PrestoSparkQueryRunner.java | 7 + .../relation/ExpressionOptimizerProvider.java | 19 ++ .../spi/relation/RowExpressionService.java | 1 + .../tests/AbstractTestQueryFramework.java | 10 +- .../presto/tests/DistributedQueryRunner.java | 8 + .../presto/tests/StandaloneQueryRunner.java | 7 + .../presto/memory/TestMemoryManager.java | 3 +- .../TestDelegatingExpressionOptimizer.java | 146 +++++++++ .../TestExpressionInterpreter.java | 225 ++++++++++++++ .../expressions/TestExpressionOptimizers.java | 133 ++++++++ .../tests/expressions/TestExpressions.java | 286 ++++-------------- .../thrift/integration/ThriftQueryRunner.java | 7 + 39 files changed, 1366 insertions(+), 290 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/relational/DelegatingRowExpressionOptimizer.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/relation/ExpressionOptimizerProvider.java create mode 100644 presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java create mode 100644 presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionInterpreter.java create mode 100644 presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionOptimizers.java rename presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java => presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java (87%) diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 8ab8d397d521..e2e0ff1eeaa0 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -326,6 +326,7 @@ public final class SystemSessionProperties public static final String INLINE_PROJECTIONS_ON_VALUES = "inline_projections_on_values"; public static final String INCLUDE_VALUES_NODE_IN_CONNECTOR_OPTIMIZER = "include_values_node_in_connector_optimizer"; public static final String SINGLE_NODE_EXECUTION_ENABLED = "single_node_execution_enabled"; + public static final String DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED = "delegating_row_expression_optimizer_enabled"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_AGGREGATION_SPILL_ALL = "native_aggregation_spill_all"; @@ -1835,6 +1836,11 @@ public SystemSessionProperties( SINGLE_NODE_EXECUTION_ENABLED, "Enable single node execution", featuresConfig.isSingleNodeExecutionEnabled(), + false), + booleanProperty( + DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, + "Enable delegating row optimizer", + featuresConfig.isDelegatingRowExpressionOptimizerEnabled(), false)); } @@ -3121,4 +3127,9 @@ public static int getMinColumnarEncodingChannelsToPreferRowWiseEncoding(Session { return session.getSystemProperty(NATIVE_MIN_COLUMNAR_ENCODING_CHANNELS_TO_PREFER_ROW_WISE_ENCODING, Integer.class); } + + public static boolean isDelegatingRowExpressionOptimizerEnabled(Session session) + { + return session.getSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java index 044785ca0a44..cb5d8ec8c591 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java @@ -31,11 +31,11 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.ExpressionAnalyzer; import com.facebook.presto.sql.analyzer.Scope; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.ExpressionInterpreter; import com.facebook.presto.sql.planner.NoOpVariableResolver; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.relational.FunctionResolution; -import com.facebook.presto.sql.relational.RowExpressionOptimizer; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.ArithmeticUnaryExpression; import com.facebook.presto.sql.tree.AstVisitor; @@ -78,11 +78,13 @@ public class ScalarStatsCalculator { private final Metadata metadata; + private final ExpressionOptimizerManager expressionOptimizerManager; @Inject - public ScalarStatsCalculator(Metadata metadata) + public ScalarStatsCalculator(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) { this.metadata = requireNonNull(metadata, "metadata can not be null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionOptimizerManager can not be null"); } @Deprecated @@ -126,7 +128,7 @@ public VariableStatsEstimate visitCall(CallExpression call, Void context) return computeArithmeticBinaryStatistics(call, context); } - RowExpression value = new RowExpressionOptimizer(metadata).optimize(call, OPTIMIZED, session); + RowExpression value = expressionOptimizerManager.getExpressionOptimizer().optimize(call, OPTIMIZED, session); if (isNull(value)) { return nullStatsEstimate(); diff --git a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java index f1cef40a1a3f..5a3a06b3b97e 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.security.SystemAccessControlFactory; import com.facebook.presto.spi.session.SessionPropertyConfigurationManagerFactory; import com.facebook.presto.spi.session.WorkerSessionPropertyProviderFactory; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; import com.facebook.presto.spi.statistics.HistoryBasedPlanStatisticsProvider; import com.facebook.presto.spi.storage.TempStorageFactory; import com.facebook.presto.spi.tracing.TracerProvider; @@ -50,6 +51,7 @@ import com.facebook.presto.spi.ttl.NodeTtlFetcherFactory; import com.facebook.presto.sql.analyzer.AnalyzerProviderManager; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager; import com.facebook.presto.storage.TempStorageManager; import com.facebook.presto.tracing.TracerProviderManager; @@ -135,6 +137,7 @@ public class PluginManager private final QueryPreparerProviderManager queryPreparerProviderManager; private final NodeStatusNotificationManager nodeStatusNotificationManager; private final PlanCheckerProviderManager planCheckerProviderManager; + private final ExpressionOptimizerManager expressionOptimizerManager; @Inject public PluginManager( @@ -157,7 +160,8 @@ public PluginManager( HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager, TracerProviderManager tracerProviderManager, NodeStatusNotificationManager nodeStatusNotificationManager, - PlanCheckerProviderManager planCheckerProviderManager) + PlanCheckerProviderManager planCheckerProviderManager, + ExpressionOptimizerManager expressionOptimizerManager) { requireNonNull(nodeInfo, "nodeInfo is null"); requireNonNull(config, "config is null"); @@ -190,6 +194,7 @@ public PluginManager( this.queryPreparerProviderManager = requireNonNull(queryPreparerProviderManager, "queryPreparerProviderManager is null"); this.nodeStatusNotificationManager = requireNonNull(nodeStatusNotificationManager, "nodeStatusNotificationManager is null"); this.planCheckerProviderManager = requireNonNull(planCheckerProviderManager, "planCheckerProviderManager is null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionManager is null"); } public void loadPlugins() @@ -372,6 +377,11 @@ public void installCoordinatorPlugin(CoordinatorPlugin plugin) log.info("Registering plan checker provider factory %s", planCheckerProviderFactory.getName()); planCheckerProviderManager.addPlanCheckerProviderFactory(planCheckerProviderFactory); } + + for (ExpressionOptimizerFactory expressionOptimizerFactory : plugin.getExpressionOptimizerFactories()) { + log.info("Registering expression optimizer factory %s", expressionOptimizerFactory.getName()); + expressionOptimizerManager.addExpressionOptimizerFactory(expressionOptimizerFactory); + } } private URLClassLoader buildClassLoader(String plugin) diff --git a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java index d8a0fe050e7a..105e740636d2 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java @@ -51,6 +51,7 @@ import com.facebook.presto.server.security.PasswordAuthenticatorManager; import com.facebook.presto.server.security.ServerSecurityModule; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager; import com.facebook.presto.storage.TempStorageManager; @@ -190,6 +191,8 @@ public void run() PluginNodeManager pluginNodeManager = new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()); planCheckerProviderManager.loadPlanCheckerProviders(pluginNodeManager); + injector.getInstance(ExpressionOptimizerManager.class).loadExpressionOptimizerFactory(); + startAssociatedProcesses(injector); injector.getInstance(Announcer.class).start(); diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index e2ab95982852..881588658003 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -194,6 +194,7 @@ import com.facebook.presto.sql.analyzer.MetadataExtractorMBean; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -359,6 +360,9 @@ else if (serverConfig.isCoordinator()) { binder.bind(SystemSessionProperties.class).in(Scopes.SINGLETON); binder.bind(SessionPropertyDefaults.class).in(Scopes.SINGLETON); + // expression manager + binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + // schema properties binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON); diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 034763c47212..e853a2846106 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -71,6 +71,7 @@ import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; @@ -168,6 +169,7 @@ public class TestingPrestoServer private final TaskManager taskManager; private final GracefulShutdownHandler gracefulShutdownHandler; private final ShutdownAction shutdownAction; + private final ExpressionOptimizerManager expressionManager; private final RequestBlocker requestBlocker; private final boolean resourceManager; private final boolean catalogServer; @@ -368,6 +370,7 @@ public TestingPrestoServer( procedureTester = injector.getInstance(ProcedureTester.class); splitManager = injector.getInstance(SplitManager.class); pageSourceManager = injector.getInstance(PageSourceManager.class); + expressionManager = injector.getInstance(ExpressionOptimizerManager.class); if (coordinator) { dispatchManager = injector.getInstance(DispatchManager.class); queryManager = injector.getInstance(QueryManager.class); @@ -385,6 +388,7 @@ public TestingPrestoServer( eventListenerManager = ((TestingEventListenerManager) injector.getInstance(EventListenerManager.class)); clusterStateProvider = null; planCheckerProviderManager = injector.getInstance(PlanCheckerProviderManager.class); + expressionManager.loadExpressionOptimizerFactory(); } else if (resourceManager) { dispatchManager = null; @@ -704,6 +708,11 @@ public ShutdownAction getShutdownAction() return shutdownAction; } + public ExpressionOptimizerManager getExpressionManager() + { + return expressionManager; + } + public boolean isCoordinator() { return coordinator; diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index f7ea9174fe8e..9c45e84f3318 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -279,6 +279,8 @@ public class FeaturesConfig private boolean generateDomainFilters; private boolean printEstimatedStatsFromCache; private boolean removeCrossJoinWithSingleConstantRow = true; + private boolean delegatingRowOptimizerEnabled; + private int delegatingRowOptimizerMaxIterations = 10; private CreateView.Security defaultViewSecurityMode = DEFINER; private boolean useHistograms; @@ -2875,4 +2877,31 @@ public FeaturesConfig setSingleNodeExecutionEnabled(boolean singleNodeExecutionE this.singleNodeExecutionEnabled = singleNodeExecutionEnabled; return this; } + + public boolean isDelegatingRowExpressionOptimizerEnabled() + { + return delegatingRowOptimizerEnabled; + } + + @Config("optimizer.delegating-row-expression-optimizer-enabled") + @ConfigDescription("Enable delegating row optimizer") + public FeaturesConfig setDelegatingRowExpressionOptimizerEnabled(boolean delegatingRowOptimizerEnabled) + { + this.delegatingRowOptimizerEnabled = delegatingRowOptimizerEnabled; + return this; + } + + @Min(1) + public int getDelegatingRowExpressionOptimizerMaxIterations() + { + return delegatingRowOptimizerMaxIterations; + } + + @Config("optimizer.delegating-row-expression-optimizer-max-iterations") + @ConfigDescription("Maximum number of iterations for delegating row optimizer") + public FeaturesConfig setDelegatingRowExpressionOptimizerMaxIterations(int delegatingRowOptimizerMaxIterations) + { + this.delegatingRowOptimizerMaxIterations = delegatingRowOptimizerMaxIterations; + return this; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java new file mode 100644 index 000000000000..303a5e2abb29 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.expressions; + +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.ExpressionOptimizerProvider; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionOptimizer; + +import javax.inject.Inject; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.presto.util.PropertiesUtil.loadProperties; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.util.Objects.requireNonNull; + +public class ExpressionOptimizerManager + implements ExpressionOptimizerProvider +{ + private static final File EXPRESSION_MANAGER_CONFIGURATION = new File("etc/expression-manager.properties"); + public static final String EXPRESSION_MANAGER_FACTORY_NAME = "expression-manager-factory.name"; + + private final Map expressionOptimizerFactories = new ConcurrentHashMap<>(); + private final AtomicReference rowExpressionInterpreter = new AtomicReference<>(); + private final NodeManager nodeManager; + private final FunctionAndTypeManager functionAndTypeManager; + private final FunctionResolution functionResolution; + private final ExpressionOptimizer defaultExpressionOptimizer; + + @Inject + public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager) + { + requireNonNull(nodeManager, "nodeManager is null"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + this.defaultExpressionOptimizer = new RowExpressionOptimizer(functionAndTypeManager); + rowExpressionInterpreter.set(defaultExpressionOptimizer); + } + + public void loadExpressionOptimizerFactory() + { + try { + if (EXPRESSION_MANAGER_CONFIGURATION.exists()) { + Map properties = new HashMap<>(loadProperties(EXPRESSION_MANAGER_CONFIGURATION)); + loadExpressionOptimizerFactory(properties); + } + } + catch (IOException e) { + throw new UncheckedIOException("Failed to load expression manager configuration", e); + } + } + + private void loadExpressionOptimizerFactory(Map properties) + { + properties = new HashMap<>(properties); + String factoryName = properties.remove(EXPRESSION_MANAGER_FACTORY_NAME); + checkArgument(!isNullOrEmpty(factoryName), "%s does not contain %s", EXPRESSION_MANAGER_CONFIGURATION, EXPRESSION_MANAGER_FACTORY_NAME); + checkArgument( + rowExpressionInterpreter.compareAndSet( + defaultExpressionOptimizer, + expressionOptimizerFactories.get(factoryName).createOptimizer(properties, new ExpressionOptimizerContext(nodeManager, functionAndTypeManager, functionResolution))), + "ExpressionManager is already loaded"); + } + + public void addExpressionOptimizerFactory(ExpressionOptimizerFactory expressionOptimizerFactory) + { + String name = expressionOptimizerFactory.getName(); + checkArgument( + this.expressionOptimizerFactories.putIfAbsent(name, expressionOptimizerFactory) == null, + "ExpressionOptimizerFactory %s is already registered", name); + } + + @Override + public ExpressionOptimizer getExpressionOptimizer() + { + return rowExpressionInterpreter.get(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index a900bb364d16..ab4b87ed9ab0 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -22,6 +22,7 @@ import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Rule; @@ -220,7 +221,8 @@ public PlanOptimizers( CostComparator costComparator, TaskCountEstimator taskCountEstimator, PartitioningProviderManager partitioningProviderManager, - FeaturesConfig featuresConfig) + FeaturesConfig featuresConfig, + ExpressionOptimizerManager expressionOptimizerManager) { this(metadata, sqlParser, @@ -235,7 +237,8 @@ public PlanOptimizers( costComparator, taskCountEstimator, partitioningProviderManager, - featuresConfig); + featuresConfig, + expressionOptimizerManager); } @PostConstruct @@ -266,7 +269,8 @@ public PlanOptimizers( CostComparator costComparator, TaskCountEstimator taskCountEstimator, PartitioningProviderManager partitioningProviderManager, - FeaturesConfig featuresConfig) + FeaturesConfig featuresConfig, + ExpressionOptimizerManager expressionOptimizerManager) { this.exporter = exporter; ImmutableList.Builder builder = ImmutableList.builder(); @@ -321,7 +325,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.>builder() - .addAll(new SimplifyRowExpressions(metadata).rules()) + .addAll(new SimplifyRowExpressions(metadata, expressionOptimizerManager, featuresConfig).rules()) .add(new PruneRedundantProjectionAssignments()) .build()); @@ -487,7 +491,7 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.>builder() .add(new InlineProjectionsOnValues(metadata.getFunctionAndTypeManager())) - .addAll(new SimplifyRowExpressions(metadata).rules()) + .addAll(new SimplifyRowExpressions(metadata, expressionOptimizerManager, featuresConfig).rules()) .build()), new IterativeOptimizer( metadata, @@ -846,7 +850,7 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges - builder.add(new CteProjectionAndPredicatePushDown(metadata)); // must run before PhysicalCteOptimizer + builder.add(new CteProjectionAndPredicatePushDown(metadata, expressionOptimizerManager, featuresConfig)); // must run before PhysicalCteOptimizer builder.add(new PhysicalCteOptimizer(metadata)); // Must run before AddExchanges builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, partitioningProviderManager, featuresConfig.isNativeExecutionEnabled()))); builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchangesForSingleNodeExecution(metadata))); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java index 9705f7e836fd..9127beacb562 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyRowExpressions.java @@ -13,22 +13,27 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.Session; import com.facebook.presto.common.type.BooleanType; import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.expressions.RowExpressionRewriter; import com.facebook.presto.expressions.RowExpressionTreeRewriter; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.relational.DelegatingRowExpressionOptimizer; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.facebook.presto.sql.relational.RowExpressionOptimizer; import com.google.common.annotations.VisibleForTesting; +import static com.facebook.presto.SystemSessionProperties.isDelegatingRowExpressionOptimizerEnabled; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND; @@ -39,44 +44,52 @@ public class SimplifyRowExpressions extends RowExpressionRewriteRuleSet { - public SimplifyRowExpressions(Metadata metadata) + public SimplifyRowExpressions(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager, FeaturesConfig featuresConfig) { - super(new Rewriter(metadata)); + super(new Rewriter(metadata, expressionOptimizerManager, featuresConfig)); } private static class Rewriter implements PlanRowExpressionRewriter { - private final RowExpressionOptimizer optimizer; + private final ExpressionOptimizer inMemoryExpressionOptimizer; + private final ExpressionOptimizer delegatingExpressionOptimizer; private final LogicalExpressionRewriter logicalExpressionRewriter; - public Rewriter(Metadata metadata) + public Rewriter(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager, FeaturesConfig featuresConfig) { requireNonNull(metadata, "metadata is null"); - this.optimizer = new RowExpressionOptimizer(metadata); + requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); + requireNonNull(featuresConfig, "featuresConfig is null"); + this.inMemoryExpressionOptimizer = new RowExpressionOptimizer(metadata); + this.delegatingExpressionOptimizer = new DelegatingRowExpressionOptimizer(metadata, expressionOptimizerManager, featuresConfig.getDelegatingRowExpressionOptimizerMaxIterations()); this.logicalExpressionRewriter = new LogicalExpressionRewriter(metadata.getFunctionAndTypeManager()); } @Override public RowExpression rewrite(RowExpression expression, Rule.Context context) { - return rewrite(expression, context.getSession().toConnectorSession()); + return rewrite(expression, context.getSession()); } - private RowExpression rewrite(RowExpression expression, ConnectorSession session) + private RowExpression rewrite(RowExpression expression, Session session) { // Rewrite RowExpression first to reduce depth of RowExpression tree by balancing AND/OR predicates. // It doesn't matter whether we rewrite/optimize first because this will be called by IterativeOptimizer. RowExpression rewritten = RowExpressionTreeRewriter.rewriteWith(logicalExpressionRewriter, expression, true); - RowExpression optimizedRowExpression = optimizer.optimize(rewritten, SERIALIZABLE, session); - return optimizedRowExpression; + if (isDelegatingRowExpressionOptimizerEnabled(session)) { + return delegatingExpressionOptimizer.optimize(rewritten, SERIALIZABLE, session.toConnectorSession()); + } + else { + return inMemoryExpressionOptimizer.optimize(rewritten, SERIALIZABLE, session.toConnectorSession()); + } } } @VisibleForTesting - public static RowExpression rewrite(RowExpression expression, Metadata metadata, ConnectorSession session) + public static RowExpression rewrite(RowExpression expression, Metadata metadata, Session session, ExpressionOptimizerManager expressionOptimizerManager, FeaturesConfig featuresConfig) { - return new Rewriter(metadata).rewrite(expression, session); + return new Rewriter(metadata, expressionOptimizerManager, featuresConfig).rewrite(expression, session); } private static class LogicalExpressionRewriter diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteProjectionAndPredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteProjectionAndPredicatePushDown.java index d64068aeb3c2..cb706a0d2c01 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteProjectionAndPredicatePushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/CteProjectionAndPredicatePushDown.java @@ -26,6 +26,8 @@ import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.PlannerUtils; import com.facebook.presto.sql.planner.RowExpressionVariableInliner; import com.facebook.presto.sql.planner.SimplePlanVisitor; @@ -87,10 +89,14 @@ public class CteProjectionAndPredicatePushDown implements PlanOptimizer { private final Metadata metadata; + private final ExpressionOptimizerManager expressionOptimizerManager; + private final FeaturesConfig featuresConfig; - public CteProjectionAndPredicatePushDown(Metadata metadata) + public CteProjectionAndPredicatePushDown(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager, FeaturesConfig featuresConfig) { - this.metadata = metadata; + this.metadata = requireNonNull(metadata, "metadata is null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); + this.featuresConfig = requireNonNull(featuresConfig, "featuresConfig is null"); } @Override @@ -383,7 +389,7 @@ private PlanNode addFilter(PlanNode node, List predicates) resultPredicate, predicates.get(i)); } } - resultPredicate = SimplifyRowExpressions.rewrite(resultPredicate, metadata, session.toConnectorSession()); + resultPredicate = SimplifyRowExpressions.rewrite(resultPredicate, metadata, session, expressionOptimizerManager, featuresConfig); return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), node, resultPredicate); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/DelegatingRowExpressionOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/DelegatingRowExpressionOptimizer.java new file mode 100644 index 000000000000..b420adc9ed18 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/DelegatingRowExpressionOptimizer.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.relational; + +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.ExpressionOptimizerProvider; +import com.facebook.presto.spi.relation.InputReferenceExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; + +import java.util.function.Function; + +import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public final class DelegatingRowExpressionOptimizer + implements ExpressionOptimizer +{ + private static final int DEFAULT_MAX_OPTIMIZATION_ATTEMPTS = 10; + private final ExpressionOptimizerProvider expressionOptimizerManager; + private final int maxOptimizationAttempts; + + public DelegatingRowExpressionOptimizer(Metadata metadata, ExpressionOptimizerProvider expressionOptimizerManager) + { + this(metadata, expressionOptimizerManager, DEFAULT_MAX_OPTIMIZATION_ATTEMPTS); + } + + public DelegatingRowExpressionOptimizer(Metadata metadata, ExpressionOptimizerProvider expressionOptimizerManager, int maxOptimizationAttempts) + { + requireNonNull(metadata, "metadata is null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); + checkArgument(maxOptimizationAttempts > 0, "maxOptimizationAttempts must be greater than 0"); + this.maxOptimizationAttempts = maxOptimizationAttempts; + } + + @Override + public RowExpression optimize(RowExpression rowExpression, Level level, ConnectorSession session) + { + ExpressionOptimizer delegate = expressionOptimizerManager.getExpressionOptimizer(); + RowExpression originalExpression; + for (int i = 0; i < maxOptimizationAttempts; i++) { + // Do not optimize ConstantExpression, and InputReferenceExpression because they cannot be optimized further + if (rowExpression instanceof ConstantExpression || rowExpression instanceof InputReferenceExpression) { + return rowExpression; + } + originalExpression = rowExpression; + rowExpression = delegate.optimize(rowExpression, level, session); + requireNonNull(rowExpression, "optimized expression is null"); + if (originalExpression.equals(rowExpression)) { + break; + } + } + return rowExpression; + } + + @Override + public Object optimize(RowExpression rowExpression, Level level, ConnectorSession session, Function variableResolver) + { + ExpressionOptimizer delegate = expressionOptimizerManager.getExpressionOptimizer(); + Object currentExpression = rowExpression; + Object originalExpression; + for (int i = 0; i < maxOptimizationAttempts; i++) { + // Do not optimize ConstantExpression, and InputReferenceExpression because they cannot be optimized further + if (currentExpression instanceof ConstantExpression || currentExpression instanceof InputReferenceExpression) { + return currentExpression; + } + originalExpression = currentExpression; + currentExpression = delegate.optimize(toRowExpression(currentExpression, rowExpression.getType()), level, session, variableResolver); + if (currentExpression == null || currentExpression.equals(originalExpression)) { + break; + } + } + return currentExpression; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java index 1ab9d2ead2e6..f4a7a067f780 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.sql.relational; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -24,23 +26,30 @@ import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; public final class RowExpressionOptimizer implements ExpressionOptimizer { - private final Metadata metadata; + private final FunctionAndTypeManager functionAndTypeManager; public RowExpressionOptimizer(Metadata metadata) { - this.metadata = requireNonNull(metadata, "metadata is null"); + this(requireNonNull(metadata, "metadata is null").getFunctionAndTypeManager()); + } + + public RowExpressionOptimizer(FunctionMetadataManager functionMetadataManager) + { + checkArgument(functionMetadataManager instanceof FunctionAndTypeManager, "Expected functionMetadataManager to be instance of FunctionAndTypeManager"); + this.functionAndTypeManager = (FunctionAndTypeManager) requireNonNull(functionMetadataManager, "functionMetadataManager is null"); } @Override public RowExpression optimize(RowExpression rowExpression, Level level, ConnectorSession session) { if (level.ordinal() <= OPTIMIZED.ordinal()) { - return toRowExpression(rowExpression.getSourceLocation(), new RowExpressionInterpreter(rowExpression, metadata.getFunctionAndTypeManager(), session, level).optimize(), rowExpression.getType()); + return toRowExpression(rowExpression.getSourceLocation(), new RowExpressionInterpreter(rowExpression, functionAndTypeManager, session, level).optimize(), rowExpression.getType()); } throw new IllegalArgumentException("Not supported optimization level: " + level); } @@ -48,7 +57,7 @@ public RowExpression optimize(RowExpression rowExpression, Level level, Connecto @Override public Object optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver) { - RowExpressionInterpreter interpreter = new RowExpressionInterpreter(expression, metadata.getFunctionAndTypeManager(), session, level); + RowExpressionInterpreter interpreter = new RowExpressionInterpreter(expression, functionAndTypeManager, session, level); return interpreter.optimize(variableResolver::apply); } } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index cefcd193422d..248bd6a11ec0 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -106,6 +106,7 @@ import com.facebook.presto.metadata.SchemaPropertyManager; import com.facebook.presto.metadata.Split; import com.facebook.presto.metadata.TablePropertyManager; +import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.operator.Driver; import com.facebook.presto.operator.DriverContext; import com.facebook.presto.operator.DriverFactory; @@ -168,6 +169,7 @@ import com.facebook.presto.sql.analyzer.JavaFeaturesConfig; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -353,6 +355,7 @@ public class LocalQueryRunner private static ExecutorService metadataExtractorExecutor = newCachedThreadPool(threadsNamed("query-execution-%s")); private final ReadWriteLock lock = new ReentrantReadWriteLock(); + private ExpressionOptimizerManager expressionOptimizerManager; private List additionalOptimizer = ImmutableList.of(); @@ -450,8 +453,12 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.planFragmenter = new PlanFragmenter(this.metadata, this.nodePartitioningManager, new QueryManagerConfig(), featuresConfig, planCheckerProviderManager); this.joinCompiler = new JoinCompiler(metadata); this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler); + + NodeInfo nodeInfo = new NodeInfo("test"); + expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager()); + this.statsNormalizer = new StatsNormalizer(); - this.scalarStatsCalculator = new ScalarStatsCalculator(metadata); + this.scalarStatsCalculator = new ScalarStatsCalculator(metadata, expressionOptimizerManager); this.filterStatsCalculator = new FilterStatsCalculator(metadata, scalarStatsCalculator, statsNormalizer); this.historyBasedPlanStatisticsManager = new HistoryBasedPlanStatisticsManager(objectMapper, createTestingSessionPropertyManager(), metadata, new HistoryBasedOptimizationConfig(), featuresConfig, new NodeVersion("1")); this.fragmentStatsProvider = new FragmentStatsProvider(); @@ -466,7 +473,6 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.expressionCompiler = new ExpressionCompiler(metadata, pageFunctionCompiler); this.joinFilterFunctionCompiler = new JoinFilterFunctionCompiler(metadata); - NodeInfo nodeInfo = new NodeInfo("test"); NodeVersion nodeVersion = new NodeVersion("testversion"); this.connectorManager = new ConnectorManager( metadata, @@ -529,7 +535,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, historyBasedPlanStatisticsManager, new TracerProviderManager(new TracingConfig()), new NodeStatusNotificationManager(), - planCheckerProviderManager); + planCheckerProviderManager, + expressionOptimizerManager); connectorManager.addConnectorFactory(globalSystemConnectorFactory); connectorManager.createConnection(GlobalSystemConnector.NAME, GlobalSystemConnector.NAME, ImmutableMap.of()); @@ -707,6 +714,12 @@ public TestingAccessControlManager getAccessControl() return accessControl; } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + return expressionOptimizerManager; + } + public ExecutorService getExecutor() { return notificationExecutor; @@ -1127,7 +1140,8 @@ public List getPlanOptimizers(boolean noExchange) new CostComparator(featuresConfig), taskCountEstimator, partitioningProviderManager, - featuresConfig).getPlanningTimeOptimizers()); + featuresConfig, + expressionOptimizerManager).getPlanningTimeOptimizers()); return planOptimizers.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/QueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/QueryRunner.java index d2121ce335f8..e62447a403e9 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/QueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/QueryRunner.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.Plan; @@ -63,6 +64,8 @@ public interface QueryRunner TestingAccessControlManager getAccessControl(); + ExpressionOptimizerManager getExpressionManager(); + MaterializedResult execute(@Language("SQL") String sql); MaterializedResult execute(Session session, @Language("SQL") String sql); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index ff8154866a60..c1446678a48c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -248,7 +248,9 @@ public void testDefaults() .setEagerPlanValidationEnabled(false) .setEagerPlanValidationThreadPoolSize(20) .setPrestoSparkExecutionEnvironment(false) - .setSingleNodeExecutionEnabled(false)); + .setSingleNodeExecutionEnabled(false) + .setDelegatingRowExpressionOptimizerEnabled(false) + .setDelegatingRowExpressionOptimizerMaxIterations(10)); } @Test @@ -446,6 +448,8 @@ public void testExplicitPropertyMappings() .put("eager-plan-validation-thread-pool-size", "2") .put("presto-spark-execution-environment", "true") .put("single-node-execution-enabled", "true") + .put("optimizer.delegating-row-expression-optimizer-enabled", "true") + .put("optimizer.delegating-row-expression-optimizer-max-iterations", "5") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -640,7 +644,9 @@ public void testExplicitPropertyMappings() .setEagerPlanValidationEnabled(true) .setEagerPlanValidationThreadPoolSize(2) .setPrestoSparkExecutionEnvironment(true) - .setSingleNodeExecutionEnabled(true); + .setSingleNodeExecutionEnabled(true) + .setDelegatingRowExpressionOptimizerEnabled(true) + .setDelegatingRowExpressionOptimizerMaxIterations(5); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index d3c644b793ca..f44809325bc0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -14,10 +14,18 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.Session; +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.functionNamespace.FunctionNamespaceManagerPlugin; import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.AggregationFunctionMetadata; +import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.FilterNode; @@ -76,8 +84,12 @@ import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; import static com.facebook.presto.common.predicate.Domain.singleValue; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.spi.StandardErrorCode.INVALID_LIMIT_CLAUSE; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; @@ -88,8 +100,6 @@ import static com.facebook.presto.spi.plan.JoinType.RIGHT; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED; -import static com.facebook.presto.sql.TestExpressionInterpreter.AVG_UDAF_CPP; -import static com.facebook.presto.sql.TestExpressionInterpreter.SQUARE_UDF_CPP; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; @@ -144,6 +154,26 @@ public class TestLogicalPlanner extends BasePlanTest { + public static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( + QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "square"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "Integer square", + RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), + "", + notVersioned()); + + public static final SqlInvokedFunction AVG_UDAF_CPP = new SqlInvokedFunction( + QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "avg"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.DOUBLE))), + parseTypeSignature(StandardTypes.DOUBLE), + "Returns mean of doubles", + RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), + "", + notVersioned(), + FunctionKind.AGGREGATE, + Optional.of(new AggregationFunctionMetadata(parseTypeSignature("ROW(double, int)"), false))); + // TODO: Use com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder#tableScan with required node/stream // partitioning to properly test aggregation, window function and join. diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java index b4ad658409b8..ddb6b0d64767 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java @@ -16,13 +16,17 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.Optimizer; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.TypeProvider; @@ -170,7 +174,12 @@ private List getMinimalOptimizers() new RuleStatsRecorder(), queryRunner.getStatsCalculator(), queryRunner.getCostCalculator(), - new SimplifyRowExpressions(metadata).rules())); + new SimplifyRowExpressions( + metadata, + new ExpressionOptimizerManager( + new PluginNodeManager(new InMemoryNodeManager()), + queryRunner.getFunctionAndTypeManager()), + new FeaturesConfig()).rules())); } private void inTransaction(Function transactionSessionConsumer) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java index 19638181d35e..8519c3e7fddb 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java @@ -14,12 +14,15 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED; import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DoubleType.DOUBLE; @@ -33,12 +36,24 @@ public class TestRemoveMapCastRule extends BaseRuleTest { - @Test - public void testSubscriptCast() + @DataProvider(name = "delegating-row-expression-optimizer-enabled") + public Object[][] delegatingDataProvider() + { + return new Object[][] { + {true}, + {false}, + }; + } + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testSubscriptCast(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + ImmutableSet.>builder() + .addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager(), new FeaturesConfig()).rules()) + .addAll(new RemoveMapCastRule(getFunctionManager()).rules()) + .build()) .setSystemProperty(REMOVE_MAP_CAST, "true") + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.toString(enableDelegatingRowExpressionOptimizer)) .on(p -> { VariableReferenceExpression a = p.variable("a", DOUBLE); VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); @@ -53,12 +68,16 @@ public void testSubscriptCast() values("feature", "key"))); } - @Test - public void testElementAtCast() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testElementAtCast(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + ImmutableSet.>builder() + .addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager(), new FeaturesConfig()).rules()) + .addAll(new RemoveMapCastRule(getFunctionManager()).rules()) + .build()) .setSystemProperty(REMOVE_MAP_CAST, "true") + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.toString(enableDelegatingRowExpressionOptimizer)) .on(p -> { VariableReferenceExpression a = p.variable("a", DOUBLE); VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteConstantArrayContainsToInExpression.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteConstantArrayContainsToInExpression.java index 47c50c31a730..69eb3b69913f 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteConstantArrayContainsToInExpression.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteConstantArrayContainsToInExpression.java @@ -15,12 +15,15 @@ import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED; import static com.facebook.presto.SystemSessionProperties.REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -32,13 +35,25 @@ public class TestRewriteConstantArrayContainsToInExpression extends BaseRuleTest { - @Test - public void testNoNull() + @DataProvider(name = "delegating-row-expression-optimizer-enabled") + public Object[][] delegatingDataProvider() + { + return new Object[][] { + {true}, + {false}, + }; + } + + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testNoNull(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll( - new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()).build()) + ImmutableSet.>builder() + .addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager(), new FeaturesConfig()).rules()) + .addAll(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()) + .build()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.toString(enableDelegatingRowExpressionOptimizer)) .on(p -> { VariableReferenceExpression a = p.variable("a", BOOLEAN); VariableReferenceExpression b = p.variable("b"); @@ -52,11 +67,12 @@ public void testNoNull() values("b"))); } - @Test - public void testDoesNotFireForNestedArray() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testDoesNotFireForNestedArray(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).projectRowExpressionRewriteRule()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, Boolean.toString(enableDelegatingRowExpressionOptimizer)) .on(p -> { VariableReferenceExpression a = p.variable("a", BOOLEAN); VariableReferenceExpression b = p.variable("b", new ArrayType(BIGINT)); @@ -67,8 +83,8 @@ public void testDoesNotFireForNestedArray() .doesNotFire(); } - @Test - public void testDoesNotFireForNull() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testDoesNotFireForNull(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).projectRowExpressionRewriteRule()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") @@ -82,8 +98,8 @@ public void testDoesNotFireForNull() .doesNotFire(); } - @Test - public void testDoesNotFireForEmpty() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testDoesNotFireForEmpty(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).projectRowExpressionRewriteRule()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") @@ -97,12 +113,14 @@ public void testDoesNotFireForEmpty() .doesNotFire(); } - @Test - public void testNotFire() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testNotFire(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll( - new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()).build()) + ImmutableSet.>builder() + .addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager(), new FeaturesConfig()).rules()) + .addAll(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()) + .build()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") .on(p -> { VariableReferenceExpression a = p.variable("a", BOOLEAN); @@ -118,12 +136,14 @@ public void testNotFire() values("b", "c"))); } - @Test - public void testWithNull() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testWithNull(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll( - new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()).build()) + ImmutableSet.>builder() + .addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager(), new FeaturesConfig()).rules()) + .addAll(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()) + .build()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") .on(p -> { VariableReferenceExpression a = p.variable("a", BOOLEAN); @@ -138,12 +158,14 @@ public void testWithNull() values("b"))); } - @Test - public void testLambda() + @Test(dataProvider = "delegating-row-expression-optimizer-enabled") + public void testLambda(boolean enableDelegatingRowExpressionOptimizer) { tester().assertThat( - ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata()).rules()).addAll( - new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()).build()) + ImmutableSet.>builder() + .addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager(), new FeaturesConfig()).rules()) + .addAll(new RewriteConstantArrayContainsToInExpression(getFunctionManager()).rules()) + .build()) .setSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, "true") .on(p -> { VariableReferenceExpression a = p.variable("a", BOOLEAN); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java index 96da619e3ccb..d81aa041bc24 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java @@ -13,15 +13,20 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.expressions.RowExpressionRewriter; import com.facebook.presto.expressions.RowExpressionTreeRewriter; +import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.tree.Expression; @@ -36,6 +41,7 @@ import java.util.stream.Stream; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.SystemSessionProperties.DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; @@ -43,6 +49,8 @@ import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static com.facebook.presto.sql.relational.Expressions.specialForm; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static java.lang.String.format; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; @@ -181,12 +189,24 @@ private static void assertSimplifies(String expression, String rowExpressionExpe { Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); + InMemoryNodeManager nodeManager = new InMemoryNodeManager(); + ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager()); + expressionOptimizerManager.loadExpressionOptimizerFactory(); + TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(METADATA); RowExpression actualRowExpression = translator.translate(actualExpression, TypeProvider.viewOf(TYPES)); - RowExpression simplifiedRowExpression = SimplifyRowExpressions.rewrite(actualRowExpression, METADATA, TEST_SESSION.toConnectorSession()); + RowExpression simplifiedRowExpression = SimplifyRowExpressions.rewrite(actualRowExpression, METADATA, TEST_SESSION, expressionOptimizerManager, new FeaturesConfig()); Expression expectedByRowExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(rowExpressionExpected)); RowExpression simplifiedByExpression = translator.translate(expectedByRowExpression, TypeProvider.viewOf(TYPES)); assertEquals(normalize(simplifiedRowExpression), normalize(simplifiedByExpression)); + + Session session = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty(DELEGATING_ROW_EXPRESSION_OPTIMIZER_ENABLED, "true") + .build(); + RowExpression sidecarSimplifiedExpressions = SimplifyRowExpressions.rewrite(actualRowExpression, METADATA, session, expressionOptimizerManager, new FeaturesConfig()); + assertEquals(normalize(sidecarSimplifiedExpressions), normalize(simplifiedByExpression)); } private static RowExpression normalize(RowExpression expression) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java index 04ce733e96aa..0dafb0c96b8b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/BaseRuleTest.java @@ -16,6 +16,7 @@ import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.Plugin; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.Plan; import com.google.common.collect.ImmutableList; import org.testng.annotations.AfterClass; @@ -84,4 +85,9 @@ protected void assertNodePresentInPlan(Plan plan, Class nodeClass) .matches(), "Expected " + nodeClass.toString() + " in plan after optimization. "); } + + protected ExpressionOptimizerManager getExpressionManager() + { + return tester.getExpressionManager(); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java index 3c9540ef42b2..8783a6125e1e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleTester.java @@ -26,6 +26,7 @@ import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.assertions.OptimizerAssert; @@ -61,6 +62,7 @@ public class RuleTester private final PageSourceManager pageSourceManager; private final AccessControl accessControl; private final SqlParser sqlParser; + private ExpressionOptimizerManager expressionOptimizerManager; public RuleTester() { @@ -107,6 +109,8 @@ public RuleTester(List plugins, Map sessionProperties, S connectorFactory, ImmutableMap.of()); plugins.stream().forEach(queryRunner::installPlugin); + expressionOptimizerManager = queryRunner.getExpressionManager(); + expressionOptimizerManager.loadExpressionOptimizerFactory(); this.metadata = queryRunner.getMetadata(); this.transactionManager = queryRunner.getTransactionManager(); @@ -197,4 +201,9 @@ public List> getTableConstraints(TableHandle table return metadata.getTableMetadata(transactionSession, tableHandle).getMetadata().getTableConstraintsHolder().getTableConstraintsWithColumnHandles(); }); } + + public ExpressionOptimizerManager getExpressionManager() + { + return expressionOptimizerManager; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCteProjectionAndPredicatePushdown.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCteProjectionAndPredicatePushdown.java index f36c3aaabcf1..0e73b6b3d551 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCteProjectionAndPredicatePushdown.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestCteProjectionAndPredicatePushdown.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.sql.Optimizer; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; @@ -144,7 +145,7 @@ private void assertCtePlan(String sql, PlanMatchPattern pattern) new RemoveIdentityProjectionsBelowProjection(), new PruneRedundantProjectionAssignments())), new PruneUnreferencedOutputs(), - new CteProjectionAndPredicatePushDown(metadata)); + new CteProjectionAndPredicatePushDown(metadata, getQueryRunner().getExpressionManager(), new FeaturesConfig())); assertPlan(sql, getSession(), Optimizer.PlanStage.OPTIMIZED, pattern, optimizers); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java new file mode 100644 index 000000000000..4d1b3606068d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/relational/TestDelegatingRowExpressionOptimizer.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.relational; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.RowExpressionVisitor; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; +import static com.facebook.presto.sql.planner.LiteralEncoder.toRowExpression; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertEquals; + +public class TestDelegatingRowExpressionOptimizer +{ + private DelegatingRowExpressionOptimizer optimizer; + private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager(); + + private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); + + @BeforeClass + public void setUp() + { + optimizer = new DelegatingRowExpressionOptimizer(METADATA, InnerOptimizer::new, 3); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + optimizer = null; + } + + @Test + public void testBasicExpressions() + { + assertEquals(optimize(expression("BIGINT'1' + 1")), expression("BIGINT'2'")); + assertEquals(optimize(expression("IF(TRUE, 1, 2)")), expression("1")); + } + + @Test + public void testVariableReference() + { + VariableReferenceExpression variable = new VariableReferenceExpression(Optional.empty(), "x", BIGINT); + assertEquals(optimize(variable), variable); + ImmutableMap typeMap = ImmutableMap.of("x", BIGINT); + assertEquals(optimize(expression("x + 1", typeMap)), expression("x + 1", typeMap)); + assertEquals(optimize(expression("x + 1", typeMap), ImmutableMap.of(variable, 1L)), expression("BIGINT'2'")); + } + + @Test + public void testComplexExpressions() + { + assertEquals(optimize(expression("IF(TRUE, 1, 2) + 3")), expression("4")); + assertEquals(optimize(expression("IF(TRUE, 1, 2) + 3 + 4")), expression("8")); + assertEquals(optimize(expression("IF(TRUE, 1, 2) + 3 + 4 + 5")), expression("8 + 5")); + + VariableReferenceExpression variable = new VariableReferenceExpression(Optional.empty(), "x", INTEGER); + ImmutableMap typeMap = ImmutableMap.of("x", INTEGER); + assertEquals(optimize(expression("IF(TRUE, 1, 2) + x", typeMap), ImmutableMap.of(variable, 3L)), expression("4")); + assertEquals(optimize(expression("IF(TRUE, 1, 2) + x + 4", typeMap), ImmutableMap.of(variable, 3L)), expression("8")); + assertEquals(optimize(expression("IF(TRUE, 1, 2) + x + 4 + 5", typeMap), ImmutableMap.of(variable, 3L)), expression("8 + 5")); + } + + @Test + public void testDifferentOptimizationLevels() + { + assertEquals(optimize(expression("rand(10) + 1 > 0"), EVALUATED), expression("true")); + assertEquals(optimize(expression("rand(10) + 1 > 0"), OPTIMIZED), expression("rand(10) + 1 > 0")); + assertEquals(optimize(expression("rand(10) + 1 > 0"), SERIALIZABLE), expression("rand(10) + 1 > 0")); + } + + private Object optimize(RowExpression expression, Map variableMap) + { + return optimize(expression, variableMap, OPTIMIZED); + } + + private Object optimize(RowExpression expression, Map variableMap, Level level) + { + return optimizer.optimize(expression, level, SESSION, variableMap::get); + } + + private RowExpression optimize(RowExpression expression) + { + return optimize(expression, OPTIMIZED); + } + + private RowExpression optimize(RowExpression expression, Level level) + { + return optimizer.optimize(expression, level, SESSION); + } + + private static RowExpression expression(String expressionSql) + { + return expression(expressionSql, ImmutableMap.of()); + } + + private static RowExpression expression(String expressionSql, Map typeMap) + { + return TRANSLATOR.translate(expressionSql, typeMap); + } + + private static class InnerOptimizer + implements ExpressionOptimizer + { + @Override + public RowExpression optimize(RowExpression rowExpression, Level level, ConnectorSession session) + { + OneLevelDeepExpressionRewriter rewriter = new OneLevelDeepExpressionRewriter(level, variable -> variable); + return rowExpression.accept(rewriter, null); + } + + @Override + public Object optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver) + { + OneLevelDeepExpressionRewriter rewriter = new OneLevelDeepExpressionRewriter(level, variableResolver); + return expression.accept(rewriter, null); + } + } + + // This visitor will only rewrite the first expression it comes across. It is intended to be used to test + // the DelegatingRowExpressionOptimizer, which will call the inner optimizer multiple times. + private static class OneLevelDeepExpressionRewriter + implements RowExpressionVisitor + { + private final RowExpressionOptimizer innerOptimizer = new RowExpressionOptimizer(METADATA); + private final Level level; + private final Function variableResolver; + + private boolean rewritten; + + public OneLevelDeepExpressionRewriter(Level level, Function variableResolver) + { + this.level = level; + this.variableResolver = requireNonNull(variableResolver, "variableResolver is null"); + } + + @Override + public RowExpression visitExpression(RowExpression node, Void context) + { + return node; + } + + @Override + public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context) + { + if (variableResolver == null) { + return reference; + } + Object value = variableResolver.apply(reference); + if (value == null) { + return reference; + } + return toRowExpression(value, reference.getType()); + } + + @Override + public RowExpression visitCall(CallExpression call, Void context) + { + List arguments = call.getArguments().stream() + .map(argument -> argument.accept(this, context)) + .collect(toImmutableList()); + if (!rewritten) { + RowExpression rewritten = toRowExpression(innerOptimizer.optimize(call, level, SESSION, variableResolver), call.getType()); + if (!rewritten.equals(call)) { + this.rewritten = true; + return rewritten; + } + } + return call(call.getDisplayName(), call.getFunctionHandle(), call.getType(), arguments); + } + + @Override + public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context) + { + List arguments = specialForm.getArguments().stream() + .map(argument -> argument.accept(this, context)) + .collect(toImmutableList()); + if (!rewritten) { + RowExpression rewritten = toRowExpression(innerOptimizer.optimize(specialForm, OPTIMIZED, SESSION, variableResolver), specialForm.getType()); + if (!rewritten.equals(specialForm)) { + this.rewritten = true; + return rewritten; + } + } + return new SpecialFormExpression(specialForm.getForm(), specialForm.getType(), arguments); + } + } +} diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java index d32eea55c392..c65afda86cf1 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/ContainerQueryRunner.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager; @@ -230,6 +231,12 @@ public TestingAccessControlManager getAccessControl() throw new UnsupportedOperationException(); } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + throw new UnsupportedOperationException(); + } + @Override public MaterializedResult execute(String sql) { diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index ed017e654cce..b84557391a8b 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -80,6 +80,7 @@ import com.facebook.presto.metadata.StaticFunctionNamespaceStore; import com.facebook.presto.metadata.StaticFunctionNamespaceStoreConfig; import com.facebook.presto.metadata.TablePropertyManager; +import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.operator.FileFragmentResultCacheConfig; import com.facebook.presto.operator.FileFragmentResultCacheManager; import com.facebook.presto.operator.FragmentCacheStats; @@ -173,6 +174,7 @@ import com.facebook.presto.sql.analyzer.MetadataExtractorMBean; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -346,6 +348,9 @@ protected void setup(Binder binder) binder.bind(AnalyzePropertyManager.class).in(Scopes.SINGLETON); binder.bind(QuerySessionSupplier.class).in(Scopes.SINGLETON); + // expression manager + binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + // tracer provider managers binder.bind(TracerProviderManager.class).in(Scopes.SINGLETON); @@ -508,6 +513,7 @@ protected void setup(Binder binder) // TODO: Decouple and remove: required by ConnectorManager binder.bind(InternalNodeManager.class).toInstance(new PrestoSparkInternalNodeManager()); + binder.bind(PluginNodeManager.class); // TODO: Decouple and remove: required by PluginManager binder.bind(InternalResourceGroupManager.class).in(Scopes.SINGLETON); diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java index f85c08ef58eb..2ddde52dfa79 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/PrestoSparkQueryRunner.java @@ -64,6 +64,7 @@ import com.facebook.presto.spi.security.PrincipalType; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; @@ -504,6 +505,12 @@ public TestingAccessControlManager getAccessControl() return testingAccessControlManager; } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + throw new UnsupportedOperationException(); + } + public HistoryBasedPlanStatisticsManager getHistoryBasedPlanStatisticsManager() { return historyBasedPlanStatisticsManager; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/relation/ExpressionOptimizerProvider.java b/presto-spi/src/main/java/com/facebook/presto/spi/relation/ExpressionOptimizerProvider.java new file mode 100644 index 000000000000..3ffde03a0d6e --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/relation/ExpressionOptimizerProvider.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.relation; + +public interface ExpressionOptimizerProvider +{ + ExpressionOptimizer getExpressionOptimizer(); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/relation/RowExpressionService.java b/presto-spi/src/main/java/com/facebook/presto/spi/relation/RowExpressionService.java index 12710f71b0f7..0d2a5545c2e9 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/relation/RowExpressionService.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/relation/RowExpressionService.java @@ -19,6 +19,7 @@ * A set of services/utilities that are helpful for connectors to operate on row expressions */ public interface RowExpressionService + extends ExpressionOptimizerProvider { DomainTranslator getDomainTranslator(); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 303ae0dfe780..bedd15a5d9e1 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.tests; +import com.facebook.airlift.node.NodeInfo; import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.CostCalculator; @@ -21,11 +22,14 @@ import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.TaskCountEstimator; import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.security.AccessDeniedException; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.QueryExplainer; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.Plan; @@ -74,6 +78,7 @@ public abstract class AbstractTestQueryFramework { + private static final NodeInfo NODE_INFO = new NodeInfo("test"); private QueryRunner queryRunner; private ExpectedQueryRunner expectedQueryRunner; private SqlParser sqlParser; @@ -568,7 +573,10 @@ private QueryExplainer getQueryExplainer() new CostComparator(featuresConfig), taskCountEstimator, new PartitioningProviderManager(), - featuresConfig) + featuresConfig, + new ExpressionOptimizerManager( + new PluginNodeManager(new InMemoryNodeManager()), + queryRunner.getMetadata().getFunctionAndTypeManager())) .getPlanningTimeOptimizers(); return new QueryExplainer( optimizers, diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java index 62502743809b..5674d59e752b 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java @@ -42,6 +42,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; @@ -626,6 +627,13 @@ public PlanCheckerProviderManager getPlanCheckerProviderManager() return coordinators.get(0).getPlanCheckerProviderManager(); } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + checkState(coordinators.size() == 1, "Expected a single coordinator"); + return coordinators.get(0).getExpressionManager(); + } + public TestingPrestoServer getCoordinator() { checkState(coordinators.size() == 1, "Expected a single coordinator"); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java index 980fb4f991bc..00993ae05495 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/StandaloneQueryRunner.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; @@ -188,6 +189,12 @@ public TestingAccessControlManager getAccessControl() return server.getAccessControl(); } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + return server.getExpressionManager(); + } + public TestingPrestoServer getServer() { return server; diff --git a/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java b/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java index 30951ad186cb..19d22733f67a 100644 --- a/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java +++ b/presto-tests/src/test/java/com/facebook/presto/memory/TestMemoryManager.java @@ -512,7 +512,7 @@ public void clusterPoolsMultiCoordinatorCleanup() queryRunner2.close(); } - @Test(timeOut = 60_000, groups = {"clusterPoolsMultiCoordinator"}) + @Test(timeOut = 600_000, groups = {"clusterPoolsMultiCoordinator"}) public void testClusterPoolsMultiCoordinator() throws Exception { @@ -544,6 +544,7 @@ public void testClusterPoolsMultiCoordinator() generalPool = memoryManager.getClusterInfo(GENERAL_POOL); reservedPool = memoryManager.getClusterInfo(RESERVED_POOL); MILLISECONDS.sleep(10); + System.out.println("waiting"); } // Make sure the queries are blocked diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java new file mode 100644 index 000000000000..616feccc3d7c --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestDelegatingExpressionOptimizer.java @@ -0,0 +1,146 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.expressions; + +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.InMemoryExpressionOptimizerProvider; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.relational.DelegatingRowExpressionOptimizer; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.google.common.collect.ImmutableList; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.type.LikeFunctions.castVarcharToLikePattern; +import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static org.testng.Assert.assertEquals; + +public class TestDelegatingExpressionOptimizer + extends TestExpressions +{ + private FunctionResolution resolution; + private ExpressionOptimizer expressionOptimizer; + + @BeforeClass + public void setup() + { + expressionOptimizer = new DelegatingRowExpressionOptimizer(getMetadata(), new InMemoryExpressionOptimizerProvider(getMetadata())); + resolution = new FunctionResolution(getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver()); + } + + @Test + public void assertLikeOptimizations() + { + assertOptimizedMatches("unbound_string LIKE bound_pattern", "unbound_string LIKE CAST('%el%' AS varchar)"); + } + + @Override + protected void assertLike(byte[] value, String pattern, boolean expected) + { + CallExpression predicate = call( + "LIKE", + resolution.likeVarcharFunction(), + BOOLEAN, + ImmutableList.of( + new ConstantExpression(wrappedBuffer(value), VARCHAR), + new ConstantExpression(castVarcharToLikePattern(utf8Slice(pattern)), LIKE_PATTERN))); + assertEquals(optimizeRowExpression(predicate, EVALUATED), expected); + } + @Override + protected Object evaluate(String expression, boolean deterministic) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + return optimizeRowExpression(rowExpression, EVALUATED); + } + + @Override + protected Object optimize(@Language("SQL") String expression) + { + assertRoundTrip(expression); + RowExpression parsedExpression = sqlToRowExpression(expression); + return optimizeRowExpression(parsedExpression, OPTIMIZED); + } + + @Override + protected Object optimizeRowExpression(RowExpression expression, Level level) + { + Object optimized = expressionOptimizer.optimize( + expression, + level, + TEST_SESSION.toConnectorSession(), + variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + } + return value; + }); + return unwrap(optimized); + } + + public Object unwrap(Object result) + { + if (result instanceof ConstantExpression) { + return ((ConstantExpression) result).getValue(); + } + else { + return result; + } + } + + @Override + protected void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + Object optimizedActual = optimize(actual); + Object optimizedExpected = optimize(expected); + assertRowExpressionEvaluationEquals(optimizedActual, optimizedExpected); + } + + @Override + protected void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + Object actualOptimized = optimize(actual); + Object expectedOptimized = optimize(expected); + assertRowExpressionEvaluationEquals( + actualOptimized, + expectedOptimized); + } + + @Override + protected void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + Object rowExpressionResult = optimizeRowExpression(rowExpression, optimizationLevel); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionInterpreter.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionInterpreter.java new file mode 100644 index 000000000000..17534292b5c2 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionInterpreter.java @@ -0,0 +1,225 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.expressions; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.operator.scalar.FunctionAssertions; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.planner.ExpressionInterpreter; +import com.facebook.presto.sql.planner.RowExpressionInterpreter; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.LikePredicate; +import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.StringLiteral; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionInterpreter; +import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionOptimizer; +import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter; +import static java.util.Collections.emptyMap; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestExpressionInterpreter + extends TestExpressions +{ + private final TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(getMetadata()); + + @Test + public void assertLikeOptimizations() + { + assertOptimizedEquals("unbound_string LIKE bound_pattern", "unbound_string LIKE bound_pattern"); + } + + @Override + protected void assertLike(byte[] value, String pattern, boolean expected) + { + Expression predicate = new LikePredicate( + rawStringLiteral(Slices.wrappedBuffer(value)), + new StringLiteral(pattern), + Optional.empty()); + assertEquals(evaluate(predicate, true), expected); + } + + private static StringLiteral rawStringLiteral(final Slice slice) + { + return new StringLiteral(slice.toStringUtf8()) + { + @Override + public Slice getSlice() + { + return slice; + } + }; + } + + @Override + protected void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertEquals(optimize(actual), optimize(expected)); + } + + @Override + protected void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + // replaces FunctionCalls to FailureFunction by fail() + Object actualOptimized = optimize(actual); + if (actualOptimized instanceof Expression) { + actualOptimized = ExpressionTreeRewriter.rewriteWith(new FailedFunctionRewriter(), (Expression) actualOptimized); + } + assertEquals( + actualOptimized, + rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected))); + } + + @Override + protected Object optimize(@Language("SQL") String expression) + { + assertRoundTrip(expression); + + Expression parsedExpression = expression(expression); + Object expressionResult = optimize(parsedExpression); + + RowExpression rowExpression = toRowExpression(parsedExpression); + Object rowExpressionResult = optimizeRowExpression(rowExpression, OPTIMIZED); + assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult); + return expressionResult; + } + + @Override + protected Object optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level) + { + RowExpressionInterpreter rowExpressionInterpreter = new RowExpressionInterpreter(expression, getMetadata(), TEST_SESSION.toConnectorSession(), level); + return rowExpressionInterpreter.optimize(variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + } + return value; + }); + } + + private Expression expression(String expression) + { + return FunctionAssertions.createExpression(expression, getMetadata(), SYMBOL_TYPES); + } + + private RowExpression toRowExpression(Expression expression) + { + return translator.translate(expression, SYMBOL_TYPES); + } + + private Object optimize(Expression expression) + { + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, getMetadata(), SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); + ExpressionInterpreter interpreter = expressionOptimizer(expression, getMetadata(), TEST_SESSION, expressionTypes); + return interpreter.optimize(variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return symbol.toSymbolReference(); + } + return value; + }); + } + + @Override + protected void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + { + assertRoundTrip(expression); + Expression translatedExpression = expression(expression); + RowExpression rowExpression = toRowExpression(translatedExpression); + + Object expressionResult = optimize(translatedExpression); + if (expressionResult instanceof Expression) { + expressionResult = toRowExpression((Expression) expressionResult); + } + Object rowExpressionResult = optimizeRowExpression(rowExpression, optimizationLevel); + assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + } + + private void assertExpressionAndRowExpressionEquals(Object expressionResult, Object rowExpressionResult) + { + if (rowExpressionResult instanceof RowExpression) { + // Cannot be completely evaluated into a constant; compare expressions + assertTrue(expressionResult instanceof Expression); + + // It is tricky to check the equivalence of an expression and a row expression. + // We rely on the optimized translator to fill the gap. + RowExpression translated = translator.translateAndOptimize((Expression) expressionResult, SYMBOL_TYPES); + assertRowExpressionEvaluationEquals(translated, rowExpressionResult); + } + else { + // We have constants; directly compare + assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); + } + } + @Override + protected Object evaluate(String expression, boolean deterministic) + { + assertRoundTrip(expression); + + Expression parsedExpression = FunctionAssertions.createExpression(expression, getMetadata(), SYMBOL_TYPES); + + return evaluate(parsedExpression, deterministic); + } + + private Object evaluate(Expression expression, boolean deterministic) + { + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, getMetadata(), SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); + Object expressionResult = expressionInterpreter(expression, getMetadata(), TEST_SESSION, expressionTypes).evaluate(); + Object rowExpressionResult = rowExpressionInterpreter(translator.translateAndOptimize(expression), getMetadata().getFunctionAndTypeManager(), TEST_SESSION.toConnectorSession()).evaluate(); + + if (deterministic) { + assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult); + } + return expressionResult; + } + + private static class FailedFunctionRewriter + extends ExpressionRewriter + { + @Override + public Expression rewriteFunctionCall(FunctionCall node, Object context, ExpressionTreeRewriter treeRewriter) + { + if (node.getName().equals(QualifiedName.of("fail"))) { + return new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(node.getArguments().get(0), new StringLiteral("ignored failure message"))); + } + return node; + } + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionOptimizers.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionOptimizers.java new file mode 100644 index 000000000000..48add5a17fd8 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressionOptimizers.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.expressions; + +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionOptimizer; +import com.google.common.collect.ImmutableList; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.type.LikeFunctions.castVarcharToLikePattern; +import static com.facebook.presto.type.LikePatternType.LIKE_PATTERN; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static org.testng.Assert.assertEquals; + +public class TestExpressionOptimizers + extends TestExpressions +{ + private final FunctionResolution resolution = new FunctionResolution(getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver()); + private ExpressionOptimizer expressionOptimizer; + + @BeforeClass + public void setup() + { + expressionOptimizer = new RowExpressionOptimizer(getMetadata().getFunctionAndTypeManager()); + } + + @Test + public void assertLikeOptimizations() + { + assertOptimizedMatches("unbound_string LIKE bound_pattern", "unbound_string LIKE CAST('%el%' AS varchar)"); + } + + @Override + protected void assertLike(byte[] value, String pattern, boolean expected) + { + CallExpression predicate = call( + "LIKE", + resolution.likeVarcharFunction(), + BOOLEAN, + ImmutableList.of( + new ConstantExpression(wrappedBuffer(value), VARCHAR), + new ConstantExpression(castVarcharToLikePattern(utf8Slice(pattern)), LIKE_PATTERN))); + assertEquals(optimizeRowExpression(predicate, EVALUATED), expected); + } + @Override + protected Object evaluate(String expression, boolean deterministic) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + return optimizeRowExpression(rowExpression, EVALUATED); + } + + @Override + protected Object optimize(@Language("SQL") String expression) + { + assertRoundTrip(expression); + RowExpression parsedExpression = sqlToRowExpression(expression); + return optimizeRowExpression(parsedExpression, OPTIMIZED); + } + + @Override + protected Object optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level) + { + return expressionOptimizer.optimize( + expression, + level, + TEST_SESSION.toConnectorSession(), + variable -> { + Symbol symbol = new Symbol(variable.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + } + return value; + }); + } + + @Override + protected void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + Object optimizedActual = optimize(actual); + Object optimizedExpected = optimize(expected); + assertEquals(optimizedActual, optimizedExpected); + } + + @Override + protected void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + Object actualOptimized = optimize(actual); + Object expectedOptimized = optimize(expected); + assertRowExpressionEvaluationEquals( + actualOptimized, + expectedOptimized); + } + + @Override + protected void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + Object rowExpressionResult = optimizeRowExpression(rowExpression, optimizationLevel); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java similarity index 87% rename from presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java rename to presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java index 37295146c429..a932a099c73e 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java +++ b/presto-tests/src/test/java/com/facebook/presto/tests/expressions/TestExpressions.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sql; +package com.facebook.presto.tests.expressions; import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; @@ -24,14 +24,11 @@ import com.facebook.presto.common.type.SqlTimestampWithTimeZone; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarbinaryType; import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.operator.scalar.FunctionAssertions; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.function.AggregationFunctionMetadata; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.Parameter; @@ -39,26 +36,19 @@ import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.InputReferenceExpression; import com.facebook.presto.spi.relation.LambdaDefinitionExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.parser.ParsingOptions; import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.planner.ExpressionInterpreter; -import com.facebook.presto.sql.planner.RowExpressionInterpreter; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.ExpressionRewriter; -import com.facebook.presto.sql.tree.ExpressionTreeRewriter; -import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.LikePredicate; -import com.facebook.presto.sql.tree.NodeRef; -import com.facebook.presto.sql.tree.QualifiedName; -import com.facebook.presto.sql.tree.StringLiteral; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -71,11 +61,9 @@ import org.joda.time.DateTimeZone; import org.joda.time.LocalDate; import org.joda.time.LocalTime; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.math.BigInteger; -import java.util.Map; import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; @@ -91,37 +79,33 @@ import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey; import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; -import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; import static com.facebook.presto.sql.ExpressionFormatter.formatExpression; -import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; -import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; -import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionInterpreter; -import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionOptimizer; -import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter; import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; import static java.lang.String.format; -import static java.util.Collections.emptyMap; import static java.util.Locale.ENGLISH; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; -public class TestExpressionInterpreter +public abstract class TestExpressions { - public static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( + private static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "square"), ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), parseTypeSignature(StandardTypes.BIGINT), @@ -130,7 +114,7 @@ public class TestExpressionInterpreter "", notVersioned()); - public static final SqlInvokedFunction AVG_UDAF_CPP = new SqlInvokedFunction( + private static final SqlInvokedFunction AVG_UDAF_CPP = new SqlInvokedFunction( QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "avg"), ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.DOUBLE))), parseTypeSignature(StandardTypes.DOUBLE), @@ -142,11 +126,11 @@ public class TestExpressionInterpreter Optional.of(new AggregationFunctionMetadata(parseTypeSignature("ROW(double, int)"), false))); private static final int TEST_VARCHAR_TYPE_LENGTH = 17; - private static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() + protected static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() .put("bound_integer", INTEGER) .put("bound_long", BIGINT) .put("bound_string", createVarcharType(TEST_VARCHAR_TYPE_LENGTH)) - .put("bound_varbinary", VarbinaryType.VARBINARY) + .put("bound_varbinary", VARBINARY) .put("bound_double", DOUBLE) .put("bound_boolean", BOOLEAN) .put("bound_date", DATE) @@ -173,16 +157,20 @@ public class TestExpressionInterpreter .put("unbound_null_string", VARCHAR) .build()); - private static final SqlParser SQL_PARSER = new SqlParser(); - private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); - private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); - private static final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(); + protected static final SqlParser SQL_PARSER = new SqlParser(); + private final Metadata metadata = createTestMetadataManager(); + private final TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(metadata); + private static final BlockEncodingSerde BLOCK_ENCODING_SERDE = new BlockEncodingManager(); - @BeforeClass - public void setup() + public TestExpressions() { - METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); - setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager()); + metadata.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); + setupJsonFunctionNamespaceManager(metadata.getFunctionAndTypeManager()); + } + + public Metadata getMetadata() + { + return metadata; } @Test @@ -415,19 +403,19 @@ public void testNonDeterministicFunctionCall() @Test public void testCppFunctionCall() { - METADATA.getFunctionAndTypeManager().createFunction(SQUARE_UDF_CPP, false); + metadata.getFunctionAndTypeManager().createFunction(SQUARE_UDF_CPP, false); assertOptimizedEquals("json.test_schema.square(-5)", "json.test_schema.square(-5)"); } @Test public void testCppAggregateFunctionCall() { - METADATA.getFunctionAndTypeManager().createFunction(AVG_UDAF_CPP, false); + metadata.getFunctionAndTypeManager().createFunction(AVG_UDAF_CPP, false); assertOptimizedEquals("json.test_schema.avg(1.0)", "json.test_schema.avg(1.0)"); } // Run this method exactly once. - private void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) + protected void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) { functionAndTypeManager.addFunctionNamespaceFactory(new JsonFileBasedFunctionNamespaceManagerFactory()); functionAndTypeManager.loadFunctionNamespaceManager( @@ -1162,14 +1150,14 @@ public void testSimpleCase() "else 3 " + "end"); - assertOptimizedEquals("case true " + + assertOptimizedMatches("case true " + "when unbound_long = 1 then 1 " + "when 0 / 0 = 0 then 2 " + "else 33 end", "" + "case true " + - "when unbound_long = 1 then 1 " + - "when 0 / 0 = 0 then 2 else 33 " + + "when unbound_long = BIGINT '1' then 1 " + + "when CAST(fail(8, 'ignored failure message') AS boolean) then 2 else 33 " + "end"); assertOptimizedEquals("case bound_long " + @@ -1199,18 +1187,6 @@ public void testSimpleCase() "when unbound_long then 4 " + "end"); - assertOptimizedMatches("case 1 " + - "when unbound_long then 1 " + - "when 0 / 0 then 2 " + - "else 1 " + - "end", - "" + - "case BIGINT '1' " + - "when unbound_long then 1 " + - "when cast(fail(8, 'ignored failure message') AS integer) then 2 " + - "else 1 " + - "end"); - assertOptimizedMatches("case 1 " + "when 0 / 0 then 1 " + "when 0 / 0 then 2 " + @@ -1394,16 +1370,15 @@ public void testLikeOptimization() assertOptimizedEquals("unbound_string LIKE 'a#_b' ESCAPE '#'", "unbound_string = CAST('a_b' AS VARCHAR)"); assertOptimizedEquals("unbound_string LIKE 'a#%b' ESCAPE '#'", "unbound_string = CAST('a%b' AS VARCHAR)"); assertOptimizedEquals("unbound_string LIKE 'a#_##b' ESCAPE '#'", "unbound_string = CAST('a_#b' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'"); - assertOptimizedEquals("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'"); + assertOptimizedMatches("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'"); + assertOptimizedMatches("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'"); assertOptimizedEquals("bound_string LIKE bound_pattern", "true"); assertOptimizedEquals("'abc' LIKE bound_pattern", "false"); - assertOptimizedEquals("unbound_string LIKE bound_pattern", "unbound_string LIKE bound_pattern"); assertDoNotOptimize("unbound_string LIKE 'abc%'", SERIALIZABLE); - assertOptimizedEquals("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string"); + assertOptimizedMatches("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string"); } @Test @@ -1586,123 +1561,27 @@ public void testLiterals() optimize("interval '3' day * unbound_long"); optimize("interval '3' year * unbound_long"); - assertEquals(optimize("X'1234'"), Slices.wrappedBuffer((byte) 0x12, (byte) 0x34)); - } - - private static void assertLike(byte[] value, String pattern, boolean expected) - { - Expression predicate = new LikePredicate( - rawStringLiteral(Slices.wrappedBuffer(value)), - new StringLiteral(pattern), - Optional.empty()); - assertEquals(evaluate(predicate, true), expected); - } - - private static StringLiteral rawStringLiteral(final Slice slice) - { - return new StringLiteral(slice.toStringUtf8()) - { - @Override - public Slice getSlice() - { - return slice; - } - }; + assertEquals(optimize("X'1234'"), wrappedBuffer((byte) 0x12, (byte) 0x34)); } + protected abstract Object evaluate(String expression, boolean deterministic); - private static void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) - { - assertEquals(optimize(actual), optimize(expected)); - } - - private static void assertRowExpressionEquals(Level level, @Language("SQL") String actual, @Language("SQL") String expected) - { - Object actualResult = optimize(toRowExpression(expression(actual)), level); - Object expectedResult = optimize(toRowExpression(expression(expected)), level); - if (actualResult instanceof Block && expectedResult instanceof Block) { - assertEquals(blockToSlice((Block) actualResult), blockToSlice((Block) expectedResult)); - return; - } - assertEquals(actualResult, expectedResult); - } - - private static void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) - { - // replaces FunctionCalls to FailureFunction by fail() - Object actualOptimized = optimize(actual); - if (actualOptimized instanceof Expression) { - actualOptimized = ExpressionTreeRewriter.rewriteWith(new FailedFunctionRewriter(), (Expression) actualOptimized); - } - assertEquals( - actualOptimized, - rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected))); - } - - private static Object optimize(@Language("SQL") String expression) - { - assertRoundTrip(expression); - - Expression parsedExpression = expression(expression); - Object expressionResult = optimize(parsedExpression); - - RowExpression rowExpression = toRowExpression(parsedExpression); - Object rowExpressionResult = optimize(rowExpression, OPTIMIZED); - assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult); - return expressionResult; - } + protected abstract Object optimize(@Language("SQL") String expression); - private static Expression expression(String expression) - { - return FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - } + protected abstract void assertOptimizedEquals(@Language("SQL") String expression, @Language("SQL") String expected); - private static RowExpression toRowExpression(Expression expression) - { - return TRANSLATOR.translate(expression, SYMBOL_TYPES); - } + protected abstract void assertLike(byte[] value, String pattern, boolean expected); - private static Object optimize(Expression expression) - { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); - ExpressionInterpreter interpreter = expressionOptimizer(expression, METADATA, TEST_SESSION, expressionTypes); - return interpreter.optimize(variable -> { - Symbol symbol = new Symbol(variable.getName()); - Object value = symbolConstant(symbol); - if (value == null) { - return symbol.toSymbolReference(); - } - return value; - }); - } + protected abstract void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected); - private static Object optimize(RowExpression expression, Level level) - { - return new RowExpressionInterpreter(expression, METADATA.getFunctionAndTypeManager(), TEST_SESSION.toConnectorSession(), level).optimize(variable -> { - Symbol symbol = new Symbol(variable.getName()); - Object value = symbolConstant(symbol); - if (value == null) { - return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); - } - return value; - }); - } + protected abstract void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel); - private static void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) + protected RowExpression sqlToRowExpression(String expression) { - assertRoundTrip(expression); - Expression translatedExpression = expression(expression); - RowExpression rowExpression = toRowExpression(translatedExpression); - - Object expressionResult = optimize(translatedExpression); - if (expressionResult instanceof Expression) { - expressionResult = toRowExpression((Expression) expressionResult); - } - Object rowExpressionResult = optimize(rowExpression, optimizationLevel); - assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); - assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + Expression parsedExpression = FunctionAssertions.createExpression(expression, metadata, SYMBOL_TYPES); + return translator.translate(parsedExpression, SYMBOL_TYPES); } - private static Object symbolConstant(Symbol symbol) + protected Object symbolConstant(Symbol symbol) { switch (symbol.getName().toLowerCase(ENGLISH)) { case "bound_integer": @@ -1733,32 +1612,14 @@ private static Object symbolConstant(Symbol symbol) return null; } - private static void assertExpressionAndRowExpressionEquals(Object expressionResult, Object rowExpressionResult) - { - if (rowExpressionResult instanceof RowExpression) { - // Cannot be completely evaluated into a constant; compare expressions - assertTrue(expressionResult instanceof Expression); - - // It is tricky to check the equivalence of an expression and a row expression. - // We rely on the optimized translator to fill the gap. - RowExpression translated = TRANSLATOR.translateAndOptimize((Expression) expressionResult, SYMBOL_TYPES); - assertRowExpressionEvaluationEquals(translated, rowExpressionResult); - } - else { - // We have constants; directly compare - assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); - } - } - /** * Assert the evaluation result of two row expressions equivalent * no matter they are constants or remaining row expressions. */ - private static void assertRowExpressionEvaluationEquals(Object left, Object right) + protected void assertRowExpressionEvaluationEquals(Object left, Object right) { if (right instanceof RowExpression) { assertTrue(left instanceof RowExpression); - // assertEquals(((RowExpression) left).getType(), ((RowExpression) right).getType()); if (left instanceof ConstantExpression) { if (isRemovableCast(right)) { assertRowExpressionEvaluationEquals(left, ((CallExpression) right).getArguments().get(0)); @@ -1770,6 +1631,13 @@ private static void assertRowExpressionEvaluationEquals(Object left, Object righ else if (left instanceof InputReferenceExpression || left instanceof VariableReferenceExpression) { assertEquals(left, right); } + else if (left instanceof CallExpression && ((CallExpression) left).getFunctionHandle().getName().contains("fail")) { + assertTrue(right instanceof CallExpression && ((CallExpression) right).getFunctionHandle().getName().contains("fail")); + assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size()); + for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) { + assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i)); + } + } else if (left instanceof CallExpression) { assertTrue(right instanceof CallExpression); assertEquals(((CallExpression) left).getFunctionHandle(), ((CallExpression) right).getFunctionHandle()); @@ -1806,68 +1674,46 @@ else if (left instanceof SpecialFormExpression) { } } - private static boolean isRemovableCast(Object value) + private boolean isRemovableCast(Object value) { if (value instanceof CallExpression && - new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) { + new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) { Type targetType = ((CallExpression) value).getType(); Type sourceType = ((CallExpression) value).getArguments().get(0).getType(); - return METADATA.getFunctionAndTypeManager().canCoerce(sourceType, targetType); + return metadata.getFunctionAndTypeManager().canCoerce(sourceType, targetType); } return false; } - private static Slice blockToSlice(Block block) + protected Slice blockToSlice(Block block) { // This function is strictly for testing use only SliceOutput sliceOutput = new DynamicSliceOutput(1000); - BlockSerdeUtil.writeBlock(blockEncodingSerde, sliceOutput, block); + BlockSerdeUtil.writeBlock(BLOCK_ENCODING_SERDE, sliceOutput, block); return sliceOutput.slice(); } - private static void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + protected void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) { assertEquals(evaluate(actual, true), evaluate(expected, true)); } - private static Object evaluate(String expression, boolean deterministic) - { - assertRoundTrip(expression); - - Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - - return evaluate(parsedExpression, deterministic); - } - - private static void assertRoundTrip(String expression) + protected void assertRoundTrip(String expression) { ParsingOptions parsingOptions = createParsingOptions(TEST_SESSION); assertEquals(SQL_PARSER.createExpression(expression, parsingOptions), SQL_PARSER.createExpression(formatExpression(SQL_PARSER.createExpression(expression, parsingOptions), Optional.empty()), parsingOptions)); } - - private static Object evaluate(Expression expression, boolean deterministic) + protected void assertRowExpressionEquals(ExpressionOptimizer.Level level, @Language("SQL") String actual, @Language("SQL") String expected) { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); - Object expressionResult = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes).evaluate(); - Object rowExpressionResult = rowExpressionInterpreter(TRANSLATOR.translateAndOptimize(expression), METADATA.getFunctionAndTypeManager(), TEST_SESSION.toConnectorSession()).evaluate(); - - if (deterministic) { - assertExpressionAndRowExpressionEquals(expressionResult, rowExpressionResult); + Object actualResult = optimizeRowExpression(sqlToRowExpression(actual), level); + Object expectedResult = optimizeRowExpression(sqlToRowExpression(expected), level); + if (actualResult instanceof Block && expectedResult instanceof Block) { + assertEquals(blockToSlice((Block) actualResult), blockToSlice((Block) expectedResult)); + return; } - return expressionResult; + assertEquals(actualResult, expectedResult); } - private static class FailedFunctionRewriter - extends ExpressionRewriter - { - @Override - public Expression rewriteFunctionCall(FunctionCall node, Object context, ExpressionTreeRewriter treeRewriter) - { - if (node.getName().equals(QualifiedName.of("fail"))) { - return new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(node.getArguments().get(0), new StringLiteral("ignored failure message"))); - } - return node; - } - } + protected abstract Object optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level); } diff --git a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java index c9a4f82b7cd8..45c030cff3ac 100644 --- a/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java +++ b/presto-thrift-connector/src/test/java/com/facebook/presto/connector/thrift/integration/ThriftQueryRunner.java @@ -34,6 +34,7 @@ import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager; @@ -254,6 +255,12 @@ public TestingAccessControlManager getAccessControl() return source.getAccessControl(); } + @Override + public ExpressionOptimizerManager getExpressionManager() + { + return source.getExpressionManager(); + } + @Override public MaterializedResult execute(String sql) {