From c149942f0f9692ad83eb683abc0a8bcaa1ef70ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 6 Dec 2024 12:30:57 +0800 Subject: [PATCH] [SPARK-50449][SQL] Fix SQL Scripting grammar allowing empty bodies for loops, IF and CASE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Before this PR, the SQL Scripting grammar allowed loops, IF and CASE statements to have empty bodies. Example: `WHILE 1 = 1 DO END WHILE;` If they have an empty body, an internal error is thrown during execution. This PR changes the grammar so that loops, IF and CASE must have at least one statement in their bodies. Note that this does not completely fix the internal error issue. It is still possible to have something like ``` WHILE 1 = 1 DO BEGIN END; END WHILE; ``` where the same error is still thrown, even though this construct is grammatically valid. This issue will be fixed by a separate PR, as non-trivial interpreter logic changes are required. ### Why are the changes needed? The existing grammar was wrong. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests that make sure parsing loops, IF and CASE with empty bodies throws an error. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48989 from dusantism-db/scripting-empty-bodies-fix. 
Authored-by: Dušan Tišma Signed-off-by: Wenchen Fan --- .../sql/catalyst/parser/SqlBaseParser.g4 | 6 +- .../sql/catalyst/parser/AstBuilder.scala | 13 +- .../parser/SqlScriptingParserSuite.scala | 137 +++++++++++++++++- 3 files changed, 145 insertions(+), 11 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 26edbe15da9fd..a0f447dba798e 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -48,15 +48,15 @@ compoundOrSingleStatement ; singleCompoundStatement - : BEGIN compoundBody END SEMICOLON? EOF + : BEGIN compoundBody? END SEMICOLON? EOF ; beginEndCompoundBlock - : beginLabel? BEGIN compoundBody END endLabel? + : beginLabel? BEGIN compoundBody? END endLabel? ; compoundBody - : (compoundStatements+=compoundStatement SEMICOLON)* + : (compoundStatements+=compoundStatement SEMICOLON)+ ; compoundStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 882e895cc7f02..fad4fcefc1d1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -144,7 +144,9 @@ class AstBuilder extends DataTypeAstBuilder override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = { val labelCtx = new SqlScriptingLabelContext() - visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = true, labelCtx) + Option(ctx.compoundBody()) + .map(visitCompoundBodyImpl(_, None, allowVarDeclare = true, labelCtx)) + .getOrElse(CompoundBody(Seq.empty, None)) } private def visitCompoundBodyImpl( @@ -191,12 +193,9 @@ class AstBuilder extends 
DataTypeAstBuilder labelCtx: SqlScriptingLabelContext): CompoundBody = { val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel())) - val body = visitCompoundBodyImpl( - ctx.compoundBody(), - Some(labelText), - allowVarDeclare = true, - labelCtx - ) + val body = Option(ctx.compoundBody()) + .map(visitCompoundBodyImpl(_, Some(labelText), allowVarDeclare = true, labelCtx)) + .getOrElse(CompoundBody(Seq.empty, Some(labelText))) labelCtx.exitLabeledScope(Option(ctx.beginLabel())) body } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index ab647f83b42a4..c9e2f42e164f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -82,7 +82,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } } - test("empty BEGIN END block") { + test("empty singleCompoundStatement") { val sqlScriptText = """ |BEGIN @@ -91,6 +91,20 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.isEmpty) } + test("empty beginEndCompoundBlock") { + val sqlScriptText = + """ + |BEGIN + | BEGIN + | END; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CompoundBody]) + val innerBody = tree.collection.head.asInstanceOf[CompoundBody] + assert(innerBody.collection.isEmpty) + } + test("multiple ; in row - should fail") { val sqlScriptText = """ @@ -439,6 +453,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(ifStmt.conditions.head.getText == "1=1") } + test("if with empty body") { + val sqlScriptText = + """BEGIN + | IF 1 = 1 THEN + | END IF; + |END + 
""".stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'IF'", "hint" -> "")) + } + test("if else") { val sqlScriptText = """BEGIN @@ -623,6 +652,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(whileStmt.label.contains("lbl")) } + test("while with empty body") { + val sqlScriptText = + """BEGIN + | WHILE 1 = 1 DO + | END WHILE; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'WHILE'", "hint" -> "")) + } + test("while with complex condition") { val sqlScriptText = """ @@ -1067,6 +1111,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(repeatStmt.label.contains("lbl")) } + test("repeat with empty body") { + val sqlScriptText = + """BEGIN + | REPEAT UNTIL 1 = 1 + | END REPEAT; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'1'", "hint" -> "")) + } + test("repeat with complex condition") { val sqlScriptText = """ @@ -1197,6 +1256,22 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(caseStmt.conditions.head.getText == "1 = 1") } + test("searched case statement with empty body") { + val sqlScriptText = + """BEGIN + | CASE + | WHEN 1 = 1 THEN + | END CASE; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'CASE'", "hint" -> "")) + } + test("searched case statement - multi when") { val sqlScriptText = """ @@ -1335,6 +1410,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == 
Literal(1)) } + test("simple case statement with empty body") { + val sqlScriptText = + """BEGIN + | CASE 1 + | WHEN 1 THEN + | END CASE; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'CASE'", "hint" -> "")) + } test("simple case statement - multi when") { val sqlScriptText = @@ -1482,6 +1572,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(whileStmt.label.contains("lbl")) } + test("loop with empty body") { + val sqlScriptText = + """BEGIN + | LOOP + | END LOOP; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'LOOP'", "hint" -> "")) + } + test("loop with if else block") { val sqlScriptText = """BEGIN @@ -1960,6 +2065,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } + test("for statement - empty body") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT 5 DO + | END FOR; + |END""".stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'FOR'", "hint" -> "")) + } + test("for statement - no label") { val sqlScriptText = """ @@ -2076,6 +2196,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } + test("for statement - no variable - empty body") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT 5 DO + | END FOR; + |END""".stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'FOR'", "hint" -> "")) + } + test("for statement - no variable - no label") { val sqlScriptText = """