From 4cdf9f2a23d9a2660024e7b95996413a05343db1 Mon Sep 17 00:00:00 2001 From: Shuo Wang Date: Wed, 17 Aug 2022 20:17:26 +0800 Subject: [PATCH] [Enhancement](Nereids) Refine nereids parser. (#11839) 1. Use ParseException in nereids parser. 2. Add check utils in the parser test. 3. Distinguish matchesFromRoot and matches when checking plans. --- .../doris/nereids/analyzer/UnboundSlot.java | 5 + .../nereids/parser/LogicalPlanBuilder.java | 51 ++++---- .../doris/nereids/parser/LimitClauseTest.java | 113 ++++++++---------- .../nereids/parser/NereidsParserTest.java | 21 ++-- .../doris/nereids/parser/ParserTestBase.java | 35 ++++++ .../rewrite/logical/ColumnPruningTest.java | 8 +- .../expressions/ExpressionParserTest.java | 38 +++++- .../nereids/trees/expressions/ViewTest.java | 4 +- .../nereids/util/AnalyzeSubQueryTest.java | 6 +- .../doris/nereids/util/ExceptionChecker.java | 48 ++++++++ .../nereids/util/ExpressionParseChecker.java | 46 +++++++ .../nereids/util/GroupMatchingUtils.java | 43 +++++++ .../doris/nereids/util/ParseChecker.java | 29 +++++ .../doris/nereids/util/PlanChecker.java | 34 +++++- .../doris/nereids/util/PlanParseChecker.java | 64 ++++++++++ 15 files changed, 419 insertions(+), 126 deletions(-) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/parser/ParserTestBase.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExceptionChecker.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionParseChecker.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/util/GroupMatchingUtils.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/util/ParseChecker.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanParseChecker.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundSlot.java index 60b662865b..045882396a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundSlot.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.util.Utils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.List; @@ -32,6 +33,10 @@ import java.util.Objects; public class UnboundSlot extends Slot implements Unbound { private final List nameParts; + public UnboundSlot(String... nameParts) { + this(ImmutableList.copyOf(nameParts)); + } + public UnboundSlot(List nameParts) { this.nameParts = Objects.requireNonNull(nameParts, "nameParts can not be null"); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index 215a993b42..9cb2b03d06 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -213,8 +213,12 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { @Override public LogicalPlan visitRegularQuerySpecification(RegularQuerySpecificationContext ctx) { return ParserUtils.withOrigin(ctx, () -> { - // TODO: support on row relation - LogicalPlan relation = withRelation(Optional.ofNullable(ctx.fromClause())); + // TODO: support one row relation + if (ctx.fromClause() == null) { + throw new ParseException("Unsupported one row relation", ctx); + } + + LogicalPlan relation = visitFromClause(ctx.fromClause()); return withSelectQuerySpecification( ctx, relation, ctx.selectClause(), @@ -326,8 +330,8 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { case DorisParser.NSEQ: return new NullSafeEqual(left, right); default: - throw new IllegalStateException("Unsupported comparison expression: " - + operator.getSymbol().getText()); + throw new ParseException("Unsupported comparison expression: " + + operator.getSymbol().getText(), ctx); } }); } @@ -356,7 +360,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { case DorisParser.OR: return new Or(left, right); default: - throw new IllegalStateException("Unsupported logical binary type: " + ctx.operator.getText()); + throw new ParseException("Unsupported logical binary type: " + ctx.operator.getText(), ctx); } }); } @@ -387,7 +391,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { case DorisParser.MINUS: //TODO: Add single operator subtraction default: - throw new IllegalStateException("Unsupported arithmetic unary type: " + ctx.operator.getText()); + throw new ParseException("Unsupported arithmetic unary type: " + ctx.operator.getText(), ctx); } }); } @@ -401,7 +405,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { int type = ctx.operator.getType(); if (left instanceof IntervalLiteral) { if (type != DorisParser.PLUS) { - throw new IllegalArgumentException("Only supported: " + Operator.ADD); + throw new ParseException("Only supported: " + Operator.ADD, ctx); } IntervalLiteral interval = (IntervalLiteral) left; return new TimestampArithmetic(Operator.ADD, right, interval.value(), interval.timeUnit(), true); @@ -414,7 +418,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { } else if (type == DorisParser.MINUS) { op = Operator.SUBTRACT; } else { - throw new IllegalArgumentException("Only supported: " + Operator.ADD + " and " + Operator.SUBTRACT); + throw new ParseException("Only supported: " + Operator.ADD + " and " + Operator.SUBTRACT, ctx); } IntervalLiteral interval = (IntervalLiteral) right; return new TimestampArithmetic(op, left, interval.value(), interval.timeUnit(), false); @@ -433,8 +437,8 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { case DorisParser.MINUS: return new Subtract(left, right); default: - throw new IllegalStateException( - "Unsupported arithmetic binary type: " + ctx.operator.getText()); + throw new ParseException( + "Unsupported arithmetic binary type: " + ctx.operator.getText(), ctx); } }); }); @@ -537,7 +541,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { case "DATETIME": return new DateTimeLiteral(value); default: - throw new IllegalStateException("Unsupported data type : " + type); + throw new ParseException("Unsupported data type : " + type, ctx); } } @@ -552,7 +556,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { return new UnboundSlot(nameParts); } else { // todo: base is an expression, may be not a table name. - throw new IllegalStateException("Unsupported dereference expression: " + ctx.getText()); + throw new ParseException("Unsupported dereference expression: " + ctx.getText(), ctx); } }); } @@ -614,7 +618,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { LogicalPlan right = plan(ctx.relationPrimary()); if (ctx.LATERAL() != null) { if (!(right instanceof LogicalSubQueryAlias)) { - throw new IllegalStateException("lateral join right table should be sub-query"); + throw new ParseException("lateral join right table should be sub-query", ctx); } } return right; @@ -719,7 +723,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { private LogicalPlan withSort(LogicalPlan input, Optional sortCtx) { return input.optionalMap(sortCtx, () -> { List orderKeys = visit(sortCtx.get().sortItem(), OrderKey.class); - return new LogicalSort(orderKeys, input); + return new LogicalSort<>(orderKeys, input); }); } @@ -732,7 +736,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { if (input instanceof LogicalSort) { offset = Long.parseLong(offsetToken.getText()); } else { - throw new IllegalStateException("OFFSET requires an ORDER BY clause"); + throw new ParseException("OFFSET requires an ORDER BY clause", limitCtx.get()); } } return new LogicalLimit<>(limit, offset, input); @@ -766,14 +770,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { }); } - private LogicalPlan withRelation(Optional ctx) { - if (ctx.isPresent()) { - return visitFromClause(ctx.get()); - } else { - throw new IllegalStateException("Unsupported one row relation"); - } - } - /** * Join one more [[LogicalPlan]]s to the current logical plan. */ @@ -916,7 +912,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { } break; default: - throw new IllegalStateException("Unsupported predicate type: " + ctx.kind.getText()); + throw new ParseException("Unsupported predicate type: " + ctx.kind.getText(), ctx); } return ctx.NOT() != null ? new Not(outExpression) : outExpression; }); @@ -925,14 +921,13 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { private List getNamedExpressions(NamedExpressionSeqContext namedCtx) { return ParserUtils.withOrigin(namedCtx, () -> { List expressions = visit(namedCtx.namedExpression(), Expression.class); - List namedExpressions = expressions.stream().map(expression -> { + return expressions.stream().map(expression -> { if (expression instanceof NamedExpression) { return (NamedExpression) expression; } else { return new UnboundAlias(expression); } }).collect(ImmutableList.toImmutableList()); - return namedExpressions; }); } @@ -947,8 +942,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { } public List withInList(PredicateContext ctx) { - List expressions = ctx.expression().stream() - .map(this::getExpression).collect(ImmutableList.toImmutableList()); - return expressions; + return ctx.expression().stream().map(this::getExpression).collect(ImmutableList.toImmutableList()); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/LimitClauseTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/LimitClauseTest.java index 40ed482528..fe407eed1c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/LimitClauseTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/LimitClauseTest.java @@ -17,92 +17,73 @@ package org.apache.doris.nereids.parser; -import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.exceptions.ParseException; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -public class LimitClauseTest { +public class LimitClauseTest extends ParserTestBase { @Test public void testLimit() { - NereidsParser nereidsParser = new NereidsParser(); - String sql = "SELECT b FROM test order by a limit 3 offset 100"; - LogicalPlan logicalPlan = nereidsParser.parseSingle(sql); - Assertions.assertTrue(logicalPlan instanceof LogicalLimit); - LogicalLimit limit = (LogicalLimit) logicalPlan; - Assertions.assertEquals(3, limit.getLimit()); - Assertions.assertEquals(100, limit.getOffset()); - Assertions.assertEquals(1, limit.children().size()); - Assertions.assertTrue(limit.child(0) instanceof LogicalSort); + parsePlan("SELECT b FROM test order by a limit 3 offset 100") + .matchesFromRoot( + logicalLimit( + logicalSort() + ).when(limit -> limit.getLimit() == 3 && limit.getOffset() == 100) + ); - sql = "SELECT b FROM test order by a limit 100, 3"; - logicalPlan = nereidsParser.parseSingle(sql); - Assertions.assertTrue(logicalPlan instanceof LogicalLimit); - limit = (LogicalLimit) logicalPlan; - Assertions.assertEquals(3, limit.getLimit()); - Assertions.assertEquals(100, limit.getOffset()); - Assertions.assertEquals(1, limit.children().size()); - Assertions.assertTrue(limit.child(0) instanceof LogicalSort); + parsePlan("SELECT b FROM test order by a limit 100, 3") + .matchesFromRoot( + logicalLimit( + logicalSort() + ).when(limit -> limit.getLimit() == 3 && limit.getOffset() == 100) + ); - sql = "SELECT b FROM test limit 3"; - logicalPlan = nereidsParser.parseSingle(sql); - Assertions.assertTrue(logicalPlan instanceof LogicalLimit); - limit = (LogicalLimit) logicalPlan; - Assertions.assertEquals(3, limit.getLimit()); - Assertions.assertEquals(0, limit.getOffset()); - Assertions.assertEquals(1, limit.children().size()); - Assertions.assertTrue(limit.child(0) instanceof LogicalProject); - sql = "SELECT b FROM test order by a limit 3"; - logicalPlan = nereidsParser.parseSingle(sql); - Assertions.assertTrue(logicalPlan instanceof LogicalLimit); - limit = (LogicalLimit) logicalPlan; - Assertions.assertEquals(3, limit.getLimit()); - Assertions.assertEquals(0, limit.getOffset()); - Assertions.assertEquals(1, limit.children().size()); - Assertions.assertTrue(limit.child(0) instanceof LogicalSort); + parsePlan("SELECT b FROM test limit 3") + .matchesFromRoot(logicalLimit().when(limit -> limit.getLimit() == 3 && limit.getOffset() == 0)); + + + parsePlan("SELECT b FROM test order by a limit 3") + .matchesFromRoot( + logicalLimit( + logicalSort() + ).when(limit -> limit.getLimit() == 3 && limit.getOffset() == 0) + ); } @Test public void testLimitExceptionCase() { - NereidsParser nereidsParser = new NereidsParser(); - IllegalStateException exception = Assertions.assertThrows( - IllegalStateException.class, - () -> { - String sql = "SELECT b FROM test limit 3 offset 100"; - nereidsParser.parseSingle(sql); - }); - Assertions.assertEquals("OFFSET requires an ORDER BY clause", - exception.getMessage()); - - exception = Assertions.assertThrows( - IllegalStateException.class, - () -> { - String sql = "SELECT b FROM test limit 100, 3"; - nereidsParser.parseSingle(sql); - }); - Assertions.assertEquals("OFFSET requires an ORDER BY clause", - exception.getMessage()); + parsePlan("SELECT b FROM test limit 3 offset 100") + .assertThrowsExactly(ParseException.class) + .assertMessageContains("\n" + + "OFFSET requires an ORDER BY clause(line 1, pos19)\n" + + "\n" + + "== SQL ==\n" + + "SELECT b FROM test limit 3 offset 100\n" + + "-------------------^^^"); + parsePlan("SELECT b FROM test limit 100, 3") + .assertThrowsExactly(ParseException.class) + .assertMessageContains("\n" + + "OFFSET requires an ORDER BY clause(line 1, pos19)\n" + + "\n" + + "== SQL ==\n" + + "SELECT b FROM test limit 100, 3\n" + + "-------------------^^^"); } @Test public void testNoLimit() { - NereidsParser nereidsParser = new NereidsParser(); - String sql = "select a from tbl order by x"; - LogicalPlan root = nereidsParser.parseSingle(sql); - Assertions.assertTrue(root instanceof LogicalSort); + parsePlan("select a from tbl order by x").matchesFromRoot(logicalSort()); } - @Test public void testNoQueryOrganization() { - NereidsParser nereidsParser = new NereidsParser(); - String sql = "select a from tbl"; - LogicalPlan root = nereidsParser.parseSingle(sql); - Assertions.assertTrue(root instanceof LogicalProject); + parsePlan("select a from tbl") + .matchesFromRoot( + logicalProject( + unboundRelation() + ) + ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/NereidsParserTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/NereidsParserTest.java index e5c1ffba03..208a45b09b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/NereidsParserTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/NereidsParserTest.java @@ -34,7 +34,7 @@ import org.junit.jupiter.api.Test; import java.util.List; -public class NereidsParserTest { +public class NereidsParserTest extends ParserTestBase { @Test public void testParseMultiple() { @@ -60,22 +60,17 @@ public class NereidsParserTest { @Test public void testErrorListener() { - Exception exception = Assertions.assertThrows(ParseException.class, () -> { - String sql = "select * from t1 where a = 1 illegal_symbol"; - NereidsParser nereidsParser = new NereidsParser(); - nereidsParser.parseSingle(sql); - }); - Assertions.assertEquals("\nextraneous input 'illegal_symbol' expecting {, ';'}(line 1, pos29)\n", - exception.getMessage()); + parsePlan("select * from t1 where a = 1 illegal_symbol") + .assertThrowsExactly(ParseException.class) + .assertMessageEquals("\nextraneous input 'illegal_symbol' expecting {, ';'}(line 1, pos29)\n"); } @Test public void testPostProcessor() { - String sql = "select `AD``D` from t1 where a = 1"; - NereidsParser nereidsParser = new NereidsParser(); - LogicalPlan logicalPlan = nereidsParser.parseSingle(sql); - LogicalProject logicalProject = (LogicalProject) logicalPlan; - Assertions.assertEquals("AD`D", logicalProject.getProjects().get(0).getName()); + parsePlan("select `AD``D` from t1 where a = 1") + .matchesFromRoot( + logicalProject().when(p -> "AD`D".equals(p.getProjects().get(0).getName())) + ); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/ParserTestBase.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/ParserTestBase.java new file mode 100644 index 0000000000..cd35c5aef8 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/parser/ParserTestBase.java @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.parser; + +import org.apache.doris.nereids.util.ExpressionParseChecker; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanParseChecker; + +/** + * Base class to check SQL parsing result. + */ +public abstract class ParserTestBase implements PatternMatchSupported { + public PlanParseChecker parsePlan(String sql) { + return new PlanParseChecker(sql); + } + + public ExpressionParseChecker parseExpression(String sql) { + return new ExpressionParseChecker(sql); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java index 54b552a152..eb7777b368 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java @@ -60,7 +60,7 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch .analyze("select id,name,grade from student left join score on student.id = score.sid" + " where score.grade > 60") .applyTopDown(new ColumnPruning()) - .matches( + .matchesFromRoot( logicalProject( logicalFilter( logicalProject( @@ -92,7 +92,7 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch + "from student left join score on student.id = score.sid " + "where score.grade > 60") .applyTopDown(new ColumnPruning()) - .matches( + .matchesFromRoot( logicalProject( logicalFilter( logicalProject( @@ -122,7 +122,7 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch PlanChecker.from(connectContext) .analyze("select id,name from student where age > 18") .applyTopDown(new ColumnPruning()) - .matches( + .matchesFromRoot( logicalProject( logicalFilter( logicalProject().when(p -> getOutputQualifiedNames(p) @@ -144,7 +144,7 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch + "on score.cid = course.cid " + "where score.grade > 60") .applyTopDown(new ColumnPruning()) - .matches( + .matchesFromRoot( logicalProject( logicalFilter( logicalProject( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java index ece7f547e1..e6dd5de565 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java @@ -17,17 +17,34 @@ package org.apache.doris.nereids.trees.expressions; +import org.apache.doris.nereids.analyzer.UnboundSlot; +import org.apache.doris.nereids.exceptions.ParseException; import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.parser.ParserTestBase; import org.junit.jupiter.api.Test; -public class ExpressionParserTest { +public class ExpressionParserTest extends ParserTestBase { private static final NereidsParser PARSER = new NereidsParser(); + /** + * This method is deprecated. + *

+ * Please use utility functions `parsePlan `in {@link ParserTestBase} + * to get {@link org.apache.doris.nereids.util.PlanParseChecker}. + */ + @Deprecated private void assertSql(String sql) { PARSER.parseSingle(sql); } + /** + * This method is deprecated. + *

+ * Please use utility functions `parseExpression` in {@link ParserTestBase} + * to get {@link org.apache.doris.nereids.util.PlanParseChecker}. + */ + @Deprecated private void assertExpr(String expr) { Expression expression = PARSER.parseExpression(expr); System.out.println(expression.toSql()); @@ -41,12 +58,18 @@ public class ExpressionParserTest { @Test public void testExprBetweenPredicate() { - String sql = "c BETWEEN a AND b"; - assertExpr(sql); + parseExpression("c BETWEEN a AND b") + .assertEquals( + new Between( + new UnboundSlot("c"), + new UnboundSlot("a"), + new UnboundSlot("b") + ) + ); } @Test - public void testInPredicate() throws Exception { + public void testInPredicate() { String in = "select * from test1 where d1 in (1, 2, 3)"; assertSql(in); @@ -55,7 +78,7 @@ public class ExpressionParserTest { } @Test - public void testSqlAnd() throws Exception { + public void testSqlAnd() { String sql = "select * from test1 where a > 1 and b > 1"; assertSql(sql); } @@ -100,6 +123,11 @@ public class ExpressionParserTest { String subtract = "3 - 2"; assertExpr(subtract); + + parseExpression("3 += 2") + .assertThrowsExactly(ParseException.class) + .assertMessageContains("extraneous input '=' expecting {'("); + } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ViewTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ViewTest.java index 3391c14b2d..cdba3ef62d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ViewTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ViewTest.java @@ -114,7 +114,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported .analyze("SELECT * FROM V1") .applyTopDown(new EliminateAliasNode()) .applyTopDown(new MergeConsecutiveProjects()) - .matches( + .matchesFromRoot( logicalProject( logicalOlapScan() ) @@ -127,7 +127,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported .analyze("SELECT * FROM (SELECT * FROM V1 JOIN V2 ON V1.ID1 = V2.ID2) X JOIN (SELECT * FROM V1 JOIN V3 ON V1.ID1 = V3.ID2) Y ON X.ID1 = Y.ID3") .applyTopDown(new EliminateAliasNode()) .applyTopDown(new MergeConsecutiveProjects()) - .matches( + .matchesFromRoot( logicalProject( logicalJoin( logicalProject( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java index 00e9de4266..7d5117d75b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java @@ -103,7 +103,7 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat PlanChecker.from(connectContext) .analyze(testSql.get(0)) .applyTopDown(new EliminateAliasNode()) - .matches( + .matchesFromRoot( logicalProject( logicalProject( logicalOlapScan().when(o -> true) @@ -123,7 +123,7 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat PlanChecker.from(connectContext) .analyze(testSql.get(1)) .applyTopDown(new EliminateAliasNode()) - .matches( + .matchesFromRoot( logicalProject( logicalJoin( logicalOlapScan(), @@ -154,7 +154,7 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat PlanChecker.from(connectContext) .analyze(testSql.get(5)) .applyTopDown(new EliminateAliasNode()) - .matches( + .matchesFromRoot( logicalProject( logicalJoin( logicalOlapScan(), diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExceptionChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExceptionChecker.java new file mode 100644 index 0000000000..ebea43fee5 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExceptionChecker.java @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.util; + +import org.junit.jupiter.api.Assertions; + +import java.util.function.Function; + +/** + * Helper to check exception message. + */ +public class ExceptionChecker { + private final Throwable exception; + + public ExceptionChecker(Throwable exception) { + this.exception = exception; + } + + public ExceptionChecker assertMessageEquals(String message) { + Assertions.assertEquals(message, exception.getMessage()); + return this; + } + + public ExceptionChecker assertMessageContains(String context) { + Assertions.assertTrue(exception.getMessage().contains(context)); + return this; + } + + public ExceptionChecker assertWith(Function asserter) { + Assertions.assertTrue(asserter.apply(exception)); + return this; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionParseChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionParseChecker.java new file mode 100644 index 0000000000..3ac6745547 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionParseChecker.java @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.util; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; +import org.junit.jupiter.api.Assertions; + +public class ExpressionParseChecker extends ParseChecker { + private final Supplier parsedSupplier; + + public ExpressionParseChecker(String sql) { + super(sql); + this.parsedSupplier = Suppliers.memoize(() -> PARSER.parseExpression(sql)); + } + + public ExpressionParseChecker assertEquals(Expression expected) { + Assertions.assertEquals(expected, parsedSupplier.get()); + return this; + } + + public ExceptionChecker assertThrows(Class expectedType) { + return new ExceptionChecker(Assertions.assertThrows(expectedType, parsedSupplier::get)); + } + + public ExceptionChecker assertThrowsExactly(Class expectedType) { + return new ExceptionChecker(Assertions.assertThrowsExactly(expectedType, parsedSupplier::get)); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/GroupMatchingUtils.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/GroupMatchingUtils.java new file mode 100644 index 0000000000..5b6338aaf9 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/GroupMatchingUtils.java @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.util; + +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.pattern.GroupExpressionMatching; +import org.apache.doris.nereids.pattern.Pattern; +import org.apache.doris.nereids.trees.plans.Plan; + +public class GroupMatchingUtils { + + public static boolean topDownFindMatching(Group group, Pattern pattern) { + GroupExpression logicalExpr = group.getLogicalExpression(); + GroupExpressionMatching matchingResult = new GroupExpressionMatching(pattern, logicalExpr); + if (matchingResult.iterator().hasNext()) { + return true; + } else { + for (Group childGroup : logicalExpr.children()) { + boolean checkResult = topDownFindMatching(childGroup, pattern); + if (checkResult) { + return true; + } + } + } + return false; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ParseChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ParseChecker.java new file mode 100644 index 0000000000..432d55150f --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ParseChecker.java @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.util; + +import org.apache.doris.nereids.parser.NereidsParser; + +public abstract class ParseChecker { + protected static final NereidsParser PARSER = new NereidsParser(); + protected final String sql; + + public ParseChecker(String sql) { + this.sql = sql; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index fb8a9453cf..bd4b805540 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -25,8 +25,11 @@ import org.apache.doris.nereids.rules.RuleFactory; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.qe.ConnectContext; +import com.google.common.base.Supplier; import org.junit.jupiter.api.Assertions; +import java.util.function.Consumer; + /** * Utility to apply rules to plan and check output plan matches the expected pattern. */ @@ -34,6 +37,7 @@ public class PlanChecker { private ConnectContext connectContext; private CascadesContext cascadesContext; + private Plan parsedPlan; public PlanChecker(ConnectContext connectContext) { this.connectContext = connectContext; @@ -44,6 +48,18 @@ public class PlanChecker { this.cascadesContext = cascadesContext; } + public PlanChecker checkParse(String sql, Consumer consumer) { + PlanParseChecker checker = new PlanParseChecker(sql); + consumer.accept(checker); + parsedPlan = checker.parsedSupplier.get(); + return this; + } + + public PlanChecker analyze() { + MemoTestUtils.createCascadesContext(connectContext, parsedPlan); + return this; + } + public PlanChecker analyze(String sql) { this.cascadesContext = MemoTestUtils.createCascadesContext(connectContext, sql); this.cascadesContext.newAnalyzer().analyze(); @@ -66,12 +82,22 @@ public class PlanChecker { return this; } + public void matchesFromRoot(PatternDescriptor patternDesc) { + Memo memo = cascadesContext.getMemo(); + assertMatches(memo, () -> new GroupExpressionMatching(patternDesc.pattern, + memo.getRoot().getLogicalExpression()).iterator().hasNext()); + } + public void matches(PatternDescriptor patternDesc) { Memo memo = cascadesContext.getMemo(); - GroupExpressionMatching matchResult = new GroupExpressionMatching(patternDesc.pattern, - memo.getRoot().getLogicalExpression()); - Assertions.assertTrue(matchResult.iterator().hasNext(), () -> - "pattern not match, plan :\n" + memo.getRoot().getLogicalExpression().getPlan().treeString() + "\n" + assertMatches(memo, () -> GroupMatchingUtils.topDownFindMatching(memo.getRoot(), patternDesc.pattern)); + } + + private void assertMatches(Memo memo, Supplier asserter) { + Assertions.assertTrue(asserter.get(), + () -> "pattern not match, plan :\n" + + memo.getRoot().getLogicalExpression().getPlan().treeString() + + "\n" ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanParseChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanParseChecker.java new file mode 100644 index 0000000000..55032a2864 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanParseChecker.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.util; + +import org.apache.doris.nereids.memo.Memo; +import org.apache.doris.nereids.pattern.GroupExpressionMatching; +import org.apache.doris.nereids.pattern.PatternDescriptor; +import org.apache.doris.nereids.trees.plans.Plan; + +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; +import org.junit.jupiter.api.Assertions; + +public class PlanParseChecker extends ParseChecker { + final Supplier parsedSupplier; + + public PlanParseChecker(String sql) { + super(sql); + this.parsedSupplier = Suppliers.memoize(() -> PARSER.parseSingle(sql)); + } + + public PlanParseChecker matches(PatternDescriptor patternDesc) { + assertMatches(() -> GroupMatchingUtils.topDownFindMatching( + new Memo(parsedSupplier.get()).getRoot(), patternDesc.pattern)); + return this; + } + + public PlanParseChecker matchesFromRoot(PatternDescriptor patternDesc) { + assertMatches(() -> new GroupExpressionMatching(patternDesc.pattern, + new Memo(parsedSupplier.get()).getRoot().getLogicalExpression()) + .iterator().hasNext()); + return this; + } + + public ExceptionChecker assertThrows(Class expectedType) { + return new ExceptionChecker(Assertions.assertThrows(expectedType, parsedSupplier::get)); + } + + public ExceptionChecker assertThrowsExactly(Class expectedType) { + return new ExceptionChecker(Assertions.assertThrowsExactly(expectedType, parsedSupplier::get)); + } + + private void assertMatches(Supplier assertResultSupplier) { + Assertions.assertTrue(assertResultSupplier.get(), + () -> "pattern not match,\ninput SQL:\n" + sql + + "\n, parsed plan :\n" + parsedSupplier.get().treeString() + "\n" + ); + } +}