From fe05eb8fa3b205b3212c25541e32b34f2167b540 Mon Sep 17 00:00:00 2001
From: milastdbx
Date: Tue, 30 Apr 2024 08:25:34 +0800
Subject: [PATCH] [SPARK-47741] Added stack overflow handling in parser

### What changes were proposed in this pull request?

The parser can hit a stack overflow when handed sufficiently complex queries, for example machine-generated ones. We need to throw a proper error so that clients can adjust their code based on the error class. This PR proposes throwing a `ParseException` with a dedicated error class when the parser hits a `StackOverflowError`.

### Why are the changes needed?

So clients can catch parser errors and act on a well-defined error class.

### Does this PR introduce _any_ user-facing change?

Yes, it adds the new error condition `FAILED_TO_PARSE_TOO_COMPLEX`.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No
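
To illustrate the intended client-side effect, here is a sketch (not part of the patch) of how an application could branch on the new error class. It assumes an active `SparkSession` named `spark`; the exact depth needed to overflow depends on the JVM stack size (`-Xss`):

```scala
import org.apache.spark.sql.catalyst.parser.ParseException

// A mechanically generated query that is too deep for the parser's stack.
val deepQuery = (1 to 20000).map(_ => "SELECT 1 AS a").mkString(" UNION ALL ")

try {
  spark.sql(deepQuery)
} catch {
  // Previously the caller saw a raw StackOverflowError; with this change it
  // can match on a stable error class and degrade gracefully.
  case e: ParseException if e.getErrorClass == "FAILED_TO_PARSE_TOO_COMPLEX" =>
    println(s"Statement too complex (SQLSTATE ${e.getSqlState}); splitting it into smaller chunks.")
}
```
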
Closes #45896 from milastdbx/dev/milast/fixExecImmStackTrace.

Authored-by: milastdbx
Signed-off-by: Wenchen Fan
---
 .../resources/error/error-conditions.json         |  7 +++
 .../spark/sql/errors/QueryParsingErrors.scala     |  6 ++
 .../catalyst/parser/AbstractSqlParser.scala       | 57 +++++++++++++------
 .../sql/errors/QueryParsingErrorsSuite.scala      | 16 ++++++
 .../ExecuteImmediateEndToEndSuite.scala           | 29 +++++++++-
 5 files changed, 98 insertions(+), 17 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 43af804fc3aeb..5791d74154162 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1318,6 +1318,13 @@
     ],
     "sqlState" : "2203G"
   },
+  "FAILED_TO_PARSE_TOO_COMPLEX" : {
+    "message" : [
+      "The statement, including potential SQL functions and referenced views, was too complex to parse.",
+      "To mitigate this error divide the statement into multiple, less complex chunks."
+    ],
+    "sqlState" : "54001"
+  },
   "FIELDS_ALREADY_EXISTS" : {
     "message" : [
       "Cannot <op> column, because <fieldName> already exists in <struct>."
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index 9d0d4ea799746..752e69f65c913 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -36,6 +36,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
     new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0001", ctx)
   }
 
+  def parserStackOverflow(parserRuleContext: ParserRuleContext): Throwable = {
+    throw new ParseException(
+      errorClass = "FAILED_TO_PARSE_TOO_COMPLEX",
+      ctx = parserRuleContext)
+  }
+
   def insertOverwriteDirectoryUnsupportedError(): Throwable = {
     SparkException.internalError("INSERT OVERWRITE DIRECTORY is not supported.")
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala
index 2d6fabaaef68a..96b9b9006c9cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.catalyst.parser
 
+import org.antlr.v4.runtime.ParserRuleContext
+
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
@@ -30,44 +32,56 @@ abstract class AbstractSqlParser extends AbstractParser with ParserInterface {
   override def astBuilder: AstBuilder
 
   /** Creates Expression for a given SQL string. */
-  override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
-    val ctx = parser.singleExpression()
-    withOrigin(ctx, Some(sqlText)) {
-      astBuilder.visitSingleExpression(ctx)
+  override def parseExpression(sqlText: String): Expression =
+    parse(sqlText) { parser =>
+      val ctx = parser.singleExpression()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitSingleExpression(ctx)
+      }
     }
-  }
 
   /** Creates TableIdentifier for a given SQL string. */
-  override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
-    astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
-  }
+  override def parseTableIdentifier(sqlText: String): TableIdentifier =
+    parse(sqlText) { parser =>
+      val ctx = parser.singleTableIdentifier()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitSingleTableIdentifier(ctx)
+      }
+    }
 
   /** Creates FunctionIdentifier for a given SQL string. */
   override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
     parse(sqlText) { parser =>
-      astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+      val ctx = parser.singleFunctionIdentifier()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitSingleFunctionIdentifier(ctx)
+      }
     }
   }
 
   /** Creates a multi-part identifier for a given SQL string */
   override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
     parse(sqlText) { parser =>
-      astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier())
+      val ctx = parser.singleMultipartIdentifier()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitSingleMultipartIdentifier(ctx)
+      }
     }
   }
 
   /** Creates LogicalPlan for a given SQL string of query.
    */
-  override def parseQuery(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
-    val ctx = parser.query()
-    withOrigin(ctx, Some(sqlText)) {
-      astBuilder.visitQuery(ctx)
+  override def parseQuery(sqlText: String): LogicalPlan =
+    parse(sqlText) { parser =>
+      val ctx = parser.query()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitQuery(ctx)
+      }
     }
-  }
 
   /** Creates LogicalPlan for a given SQL string. */
   override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
     val ctx = parser.singleStatement()
-    withOrigin(ctx, Some(sqlText)) {
+    withErrorHandling(ctx, Some(sqlText)) {
       astBuilder.visitSingleStatement(ctx) match {
         case plan: LogicalPlan => plan
         case _ =>
@@ -76,4 +90,15 @@ abstract class AbstractSqlParser extends AbstractParser with ParserInterface {
       }
     }
   }
+
+  def withErrorHandling[T](ctx: ParserRuleContext, sqlText: Option[String])(toResult: => T): T = {
+    withOrigin(ctx, sqlText) {
+      try {
+        toResult
+      } catch {
+        case so: StackOverflowError =>
+          throw QueryParsingErrors.parserStackOverflow(ctx)
+      }
+    }
+  }
 }
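
Aside for readers: the guard added above, reduced to its essentials. This standalone sketch (hypothetical names, independent of the patch) shows why `StackOverflowError` must be caught explicitly: it extends `java.lang.Error`, so `scala.util.control.NonFatal` and ordinary `case e: Exception` handlers would let it escape:

```scala
// Standalone sketch of the catch-and-rethrow pattern (hypothetical names).
final class TooComplexException(message: String) extends RuntimeException(message)

// Stand-in for ANTLR's recursive-descent visitor: recursion depth grows with
// input size. Not tail-recursive (because of `1 +`), so deep inputs genuinely
// overflow the stack.
def visitDeeply(depth: Long): Long =
  if (depth == 0) 0 else 1 + visitDeeply(depth - 1)

def withErrorHandling[T](toResult: => T): T =
  try {
    toResult
  } catch {
    // StackOverflowError is an Error, not an Exception; NonFatal would not match it.
    case _: StackOverflowError =>
      throw new TooComplexException("FAILED_TO_PARSE_TOO_COMPLEX")
  }

// withErrorHandling(visitDeeply(100000000L)) // throws TooComplexException
```

Catching `StackOverflowError` is normally suspect, but it is bounded here: by the time the handler runs, the overflowing frames have been unwound, and the parse is abandoned wholesale rather than resumed.
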
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
index d381dae6ea293..5babce0ddb8dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
@@ -32,6 +32,22 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL
     intercept[ParseException](sql(sqlText).collect())
   }
 
+  test("PARSE_STACK_OVERFLOW_ERROR: Stack overflow hit") {
+    val query = (1 to 20000).map(x => "SELECT 1 as a").mkString(" UNION ALL ")
+    val e = intercept[ParseException] {
+      spark.sql(query)
+    }
+    checkError(
+      exception = parseException(query),
+      errorClass = "FAILED_TO_PARSE_TOO_COMPLEX",
+      parameters = Map(),
+      context = ExpectedContext(
+        query,
+        start = 0,
+        stop = query.length - 1)
+    )
+  }
+
   test("EXEC_IMMEDIATE_DUPLICATE_ARGUMENT_ALIASES: duplicate aliases provided in using statement") {
     val query = "EXECUTE IMMEDIATE 'SELECT 1707 WHERE ? = 1' USING 1 as first" +
       ", 2 as first, 3 as second, 4 as second, 5 as third"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExecuteImmediateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExecuteImmediateEndToEndSuite.scala
index 41ddcef89b7d4..6b0f0b5582dc5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExecuteImmediateEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExecuteImmediateEndToEndSuite.scala
@@ -16,7 +16,8 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{QueryTest}
+import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.test.SharedSparkSession
 
 class ExecuteImmediateEndToEndSuite extends QueryTest with SharedSparkSession {
@@ -36,4 +37,30 @@ class ExecuteImmediateEndToEndSuite extends QueryTest with SharedSparkSession {
       spark.sql("DROP TEMPORARY VARIABLE IF EXISTS parm;")
     }
   }
+
+  test("EXEC IMMEDIATE STACK OVERFLOW") {
+    try {
+      spark.sql("DECLARE parm = 1;")
+      val query = (1 to 20000).map(x => "SELECT 1 as a").mkString(" UNION ALL ")
+      Seq(
+        s"EXECUTE IMMEDIATE '$query'",
+        s"EXECUTE IMMEDIATE '$query' INTO parm").foreach { q =>
+        val e = intercept[ParseException] {
+          spark.sql(q)
+        }
+
+        checkError(
+          exception = intercept[ParseException](sql(query).collect()),
+          errorClass = "FAILED_TO_PARSE_TOO_COMPLEX",
+          parameters = Map(),
+          context = ExpectedContext(
+            query,
+            start = 0,
+            stop = query.length - 1)
+        )
+      }
+    } finally {
+      spark.sql("DROP TEMPORARY VARIABLE IF EXISTS parm;")
+    }
+  }
 }