[SPARK-47741] Added stack overflow handling in parser
### What changes were proposed in this pull request?
The parser can throw a stack overflow error when it is given sufficiently complex queries.
We need to throw a proper error so that clients can adjust their code based on the error class.

This PR proposes throwing a `ParseException` with a proper error class when the parser hits a `StackOverflowError`.

### Why are the changes needed?
So that clients can catch parser errors and act on the proper error class.
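For illustration, a minimal client-side sketch of acting on the error class (hedged: `runOrSplit` is a hypothetical helper, and `getErrorClass` is the accessor exposed by Spark's `SparkThrowable`-based exceptions):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.parser.ParseException

// Hypothetical client helper: detect the new error class and fall back
// to a recovery path instead of failing on an opaque stack overflow.
def runOrSplit(spark: SparkSession, sqlText: String): Unit = {
  try {
    spark.sql(sqlText).collect()
  } catch {
    case e: ParseException if e.getErrorClass == "FAILED_TO_PARSE_TOO_COMPLEX" =>
      // Per the error's mitigation hint: divide the statement into
      // multiple, less complex chunks and run them separately.
      println(s"Statement too complex to parse: ${e.getMessage}")
  }
}
```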

### Does this PR introduce _any_ user-facing change?
Yes, it adds the new error condition `FAILED_TO_PARSE_TOO_COMPLEX`.

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #45896 from milastdbx/dev/milast/fixExecImmStackTrace.

Authored-by: milastdbx <milan.stefanovic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
milastdbx authored and cloud-fan committed Apr 30, 2024
1 parent 3fbcb26 commit fe05eb8
Showing 5 changed files with 98 additions and 17 deletions.
7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -1318,6 +1318,13 @@
],
"sqlState" : "2203G"
},
"FAILED_TO_PARSE_TOO_COMPLEX" : {
"message" : [
"The statement, including potential SQL functions and referenced views, was too complex to parse.",
"To mitigate this error divide the statement into multiple, less complex chunks."
],
"sqlState" : "54001"
},
"FIELDS_ALREADY_EXISTS" : {
"message" : [
"Cannot <op> column, because <fieldNames> already exists in <struct>."
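The new message tells users to divide the statement into multiple, less complex chunks. A hedged sketch of that mitigation on the client side (the `unionInChunks` helper and the chunk size are illustrative only, not part of this change):

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Illustrative mitigation: instead of one statement with thousands of
// UNION ALL branches, parse smaller statements and combine the results
// at the DataFrame level. Assumes `selects` is non-empty.
def unionInChunks(spark: SparkSession, selects: Seq[String], chunkSize: Int = 500): DataFrame =
  selects
    .grouped(chunkSize)
    .map(chunk => spark.sql(chunk.mkString(" UNION ALL ")))
    .reduce(_ union _)
```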
@@ -36,6 +36,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0001", ctx)
}

+  def parserStackOverflow(parserRuleContext: ParserRuleContext): Throwable = {
+    throw new ParseException(
+      errorClass = "FAILED_TO_PARSE_TOO_COMPLEX",
+      ctx = parserRuleContext)
+  }

def insertOverwriteDirectoryUnsupportedError(): Throwable = {
SparkException.internalError("INSERT OVERWRITE DIRECTORY is not supported.")
}
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.parser

+import org.antlr.v4.runtime.ParserRuleContext

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
@@ -30,44 +32,56 @@ abstract class AbstractSqlParser extends AbstractParser with ParserInterface {
override def astBuilder: AstBuilder

/** Creates Expression for a given SQL string. */
-  override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
-    val ctx = parser.singleExpression()
-    withOrigin(ctx, Some(sqlText)) {
-      astBuilder.visitSingleExpression(ctx)
+  override def parseExpression(sqlText: String): Expression =
+    parse(sqlText) { parser =>
+      val ctx = parser.singleExpression()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitSingleExpression(ctx)
+      }
    }
-  }

/** Creates TableIdentifier for a given SQL string. */
-  override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
-    astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
-  }
+  override def parseTableIdentifier(sqlText: String): TableIdentifier =
+    parse(sqlText) { parser =>
+      val ctx = parser.singleTableIdentifier()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitSingleTableIdentifier(ctx)
+      }
+    }

/** Creates FunctionIdentifier for a given SQL string. */
override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
parse(sqlText) { parser =>
-      astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+      val ctx = parser.singleFunctionIdentifier()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitSingleFunctionIdentifier(ctx)
+      }
}
}

/** Creates a multi-part identifier for a given SQL string */
override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
parse(sqlText) { parser =>
-      astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier())
+      val ctx = parser.singleMultipartIdentifier()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitSingleMultipartIdentifier(ctx)
+      }
}
}

/** Creates LogicalPlan for a given SQL string of query. */
-  override def parseQuery(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
-    val ctx = parser.query()
-    withOrigin(ctx, Some(sqlText)) {
-      astBuilder.visitQuery(ctx)
+  override def parseQuery(sqlText: String): LogicalPlan =
+    parse(sqlText) { parser =>
+      val ctx = parser.query()
+      withErrorHandling(ctx, Some(sqlText)) {
+        astBuilder.visitQuery(ctx)
+      }
    }
-  }

/** Creates LogicalPlan for a given SQL string. */
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
val ctx = parser.singleStatement()
-    withOrigin(ctx, Some(sqlText)) {
+    withErrorHandling(ctx, Some(sqlText)) {
astBuilder.visitSingleStatement(ctx) match {
case plan: LogicalPlan => plan
case _ =>
@@ -76,4 +90,15 @@ abstract class AbstractSqlParser extends AbstractParser with ParserInterface {
}
}
}

+  def withErrorHandling[T](ctx: ParserRuleContext, sqlText: Option[String])(toResult: => T): T = {
+    withOrigin(ctx, sqlText) {
+      try {
+        toResult
+      } catch {
+        case so: StackOverflowError =>
+          throw QueryParsingErrors.parserStackOverflow(ctx)
+      }
+    }
+  }
}
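The `withErrorHandling` wrapper added above reduces to a small reusable pattern, sketched standalone below (not the exact Spark code; names are illustrative). It works because by the time the `catch` clause runs, the deep recursion has unwound, leaving stack room to construct and throw a richer, catchable error:

```scala
// Sketch: surface a StackOverflowError from deeply recursive work as a
// domain-specific exception the caller can match on.
def withOverflowGuard[T](onOverflow: => Throwable)(body: => T): T =
  try body
  catch { case _: StackOverflowError => throw onOverflow }

// e.g. withOverflowGuard(new IllegalStateException("input too complex")) {
//   visitDeeplyNestedTree(root) // hypothetical recursive visitor
// }
```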
@@ -32,6 +32,22 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL
intercept[ParseException](sql(sqlText).collect())
}

test("PARSE_STACK_OVERFLOW_ERROR: Stack overflow hit") {
val query = (1 to 20000).map(x => "SELECT 1 as a").mkString(" UNION ALL ")
val e = intercept[ParseException] {
spark.sql(query)
}
checkError(
exception = parseException(query),
errorClass = "FAILED_TO_PARSE_TOO_COMPLEX",
parameters = Map(),
context = ExpectedContext(
query,
start = 0,
stop = query.length - 1)
)
}

test("EXEC_IMMEDIATE_DUPLICATE_ARGUMENT_ALIASES: duplicate aliases provided in using statement") {
val query = "EXECUTE IMMEDIATE 'SELECT 1707 WHERE ? = 1' USING 1 as first" +
", 2 as first, 3 as second, 4 as second, 5 as third"
Expand Down
@@ -16,7 +16,8 @@
*/
package org.apache.spark.sql.execution

import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.test.SharedSparkSession

class ExecuteImmediateEndToEndSuite extends QueryTest with SharedSparkSession {
@@ -36,4 +37,30 @@ class ExecuteImmediateEndToEndSuite extends QueryTest with SharedSparkSession {
spark.sql("DROP TEMPORARY VARIABLE IF EXISTS parm;")
}
}

test("EXEC IMMEDIATE STACK OVERFLOW") {
try {
spark.sql("DECLARE parm = 1;")
val query = (1 to 20000).map(x => "SELECT 1 as a").mkString(" UNION ALL ")
Seq(
s"EXECUTE IMMEDIATE '$query'",
s"EXECUTE IMMEDIATE '$query' INTO parm").foreach { q =>
val e = intercept[ParseException] {
spark.sql(q)
}

checkError(
exception = intercept[ParseException](sql(query).collect()),
errorClass = "FAILED_TO_PARSE_TOO_COMPLEX",
parameters = Map(),
context = ExpectedContext(
query,
start = 0,
stop = query.length - 1)
)
}
} finally {
spark.sql("DROP TEMPORARY VARIABLE IF EXISTS parm;")
}
}
}
