diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 26edbe15da9fd..a0f447dba798e 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -48,15 +48,15 @@ compoundOrSingleStatement ; singleCompoundStatement - : BEGIN compoundBody END SEMICOLON? EOF + : BEGIN compoundBody? END SEMICOLON? EOF ; beginEndCompoundBlock - : beginLabel? BEGIN compoundBody END endLabel? + : beginLabel? BEGIN compoundBody? END endLabel? ; compoundBody - : (compoundStatements+=compoundStatement SEMICOLON)* + : (compoundStatements+=compoundStatement SEMICOLON)+ ; compoundStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 882e895cc7f02..fad4fcefc1d1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -144,7 +144,9 @@ class AstBuilder extends DataTypeAstBuilder override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = { val labelCtx = new SqlScriptingLabelContext() - visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = true, labelCtx) + Option(ctx.compoundBody()) + .map(visitCompoundBodyImpl(_, None, allowVarDeclare = true, labelCtx)) + .getOrElse(CompoundBody(Seq.empty, None)) } private def visitCompoundBodyImpl( @@ -191,12 +193,9 @@ class AstBuilder extends DataTypeAstBuilder labelCtx: SqlScriptingLabelContext): CompoundBody = { val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel())) - val body = visitCompoundBodyImpl( - ctx.compoundBody(), - Some(labelText), - allowVarDeclare = true, - labelCtx - ) + val body = Option(ctx.compoundBody()) + .map(visitCompoundBodyImpl(_, Some(labelText), allowVarDeclare = true, labelCtx)) + .getOrElse(CompoundBody(Seq.empty, Some(labelText))) labelCtx.exitLabeledScope(Option(ctx.beginLabel())) body } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index ab647f83b42a4..c9e2f42e164f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -82,7 +82,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } } - test("empty BEGIN END block") { + test("empty singleCompoundStatement") { val sqlScriptText = """ |BEGIN @@ -91,6 +91,20 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.isEmpty) } + test("empty beginEndCompoundBlock") { + val sqlScriptText = + """ + |BEGIN + | BEGIN + | END; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CompoundBody]) + val innerBody = tree.collection.head.asInstanceOf[CompoundBody] + assert(innerBody.collection.isEmpty) + } + test("multiple ; in row - should fail") { val sqlScriptText = """ @@ -439,6 +453,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(ifStmt.conditions.head.getText == "1=1") } + test("if with empty body") { + val sqlScriptText = + """BEGIN + | IF 1 = 1 THEN + | END IF; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'IF'", "hint" -> "")) + } + test("if else") { val sqlScriptText = """BEGIN @@ -623,6 +652,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(whileStmt.label.contains("lbl")) } + test("while with empty body") { + val sqlScriptText = + """BEGIN + | WHILE 1 = 1 DO + | END WHILE; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'WHILE'", "hint" -> "")) + } + test("while with complex condition") { val sqlScriptText = """ @@ -1067,6 +1111,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(repeatStmt.label.contains("lbl")) } + test("repeat with empty body") { + val sqlScriptText = + """BEGIN + | REPEAT UNTIL 1 = 1 + | END REPEAT; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'1'", "hint" -> "")) + } + test("repeat with complex condition") { val sqlScriptText = """ @@ -1197,6 +1256,22 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(caseStmt.conditions.head.getText == "1 = 1") } + test("searched case statement with empty body") { + val sqlScriptText = + """BEGIN + | CASE + | WHEN 1 = 1 THEN + | END CASE; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'CASE'", "hint" -> "")) + } + test("searched case statement - multi when") { val sqlScriptText = """ @@ -1335,6 +1410,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1)) } + test("simple case statement with empty body") { + val sqlScriptText = + """BEGIN + | CASE 1 + | WHEN 1 THEN + | END CASE; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'CASE'", "hint" -> "")) + } test("simple case statement - multi when") { val sqlScriptText = @@ -1482,6 +1572,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(whileStmt.label.contains("lbl")) } + test("loop with empty body") { + val sqlScriptText = + """BEGIN + | LOOP + | END LOOP; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'LOOP'", "hint" -> "")) + } + test("loop with if else block") { val sqlScriptText = """BEGIN @@ -1960,6 +2065,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } + test("for statement - empty body") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT 5 DO + | END FOR; + |END""".stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'FOR'", "hint" -> "")) + } + test("for statement - no label") { val sqlScriptText = """ @@ -2076,6 +2196,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } + test("for statement - no variable - empty body") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT 5 DO + | END FOR; + |END""".stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'FOR'", "hint" -> "")) + } + test("for statement - no variable - no label") { val sqlScriptText = """