[Enhancement](Nereids) Refine nereids parser. (#11839)

1. Use ParseException in nereids parser.
2. Add check utils in the parser test.
3. Distinguish matchesFromRoot and matches when checking plans.
This commit is contained in:
Shuo Wang
2022-08-17 20:17:26 +08:00
committed by GitHub
parent 11dc5cad83
commit 4cdf9f2a23
15 changed files with 419 additions and 126 deletions

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
@ -32,6 +33,10 @@ import java.util.Objects;
public class UnboundSlot extends Slot implements Unbound {
private final List<String> nameParts;
public UnboundSlot(String... nameParts) {
this(ImmutableList.copyOf(nameParts));
}
public UnboundSlot(List<String> nameParts) {
this.nameParts = Objects.requireNonNull(nameParts, "nameParts can not be null");
}

View File

@ -213,8 +213,12 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
@Override
public LogicalPlan visitRegularQuerySpecification(RegularQuerySpecificationContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {
// TODO: support on row relation
LogicalPlan relation = withRelation(Optional.ofNullable(ctx.fromClause()));
// TODO: support one row relation
if (ctx.fromClause() == null) {
throw new ParseException("Unsupported one row relation", ctx);
}
LogicalPlan relation = visitFromClause(ctx.fromClause());
return withSelectQuerySpecification(
ctx, relation,
ctx.selectClause(),
@ -326,8 +330,8 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
case DorisParser.NSEQ:
return new NullSafeEqual(left, right);
default:
throw new IllegalStateException("Unsupported comparison expression: "
+ operator.getSymbol().getText());
throw new ParseException("Unsupported comparison expression: "
+ operator.getSymbol().getText(), ctx);
}
});
}
@ -356,7 +360,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
case DorisParser.OR:
return new Or(left, right);
default:
throw new IllegalStateException("Unsupported logical binary type: " + ctx.operator.getText());
throw new ParseException("Unsupported logical binary type: " + ctx.operator.getText(), ctx);
}
});
}
@ -387,7 +391,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
case DorisParser.MINUS:
//TODO: Add single operator subtraction
default:
throw new IllegalStateException("Unsupported arithmetic unary type: " + ctx.operator.getText());
throw new ParseException("Unsupported arithmetic unary type: " + ctx.operator.getText(), ctx);
}
});
}
@ -401,7 +405,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
int type = ctx.operator.getType();
if (left instanceof IntervalLiteral) {
if (type != DorisParser.PLUS) {
throw new IllegalArgumentException("Only supported: " + Operator.ADD);
throw new ParseException("Only supported: " + Operator.ADD, ctx);
}
IntervalLiteral interval = (IntervalLiteral) left;
return new TimestampArithmetic(Operator.ADD, right, interval.value(), interval.timeUnit(), true);
@ -414,7 +418,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
} else if (type == DorisParser.MINUS) {
op = Operator.SUBTRACT;
} else {
throw new IllegalArgumentException("Only supported: " + Operator.ADD + " and " + Operator.SUBTRACT);
throw new ParseException("Only supported: " + Operator.ADD + " and " + Operator.SUBTRACT, ctx);
}
IntervalLiteral interval = (IntervalLiteral) right;
return new TimestampArithmetic(op, left, interval.value(), interval.timeUnit(), false);
@ -433,8 +437,8 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
case DorisParser.MINUS:
return new Subtract(left, right);
default:
throw new IllegalStateException(
"Unsupported arithmetic binary type: " + ctx.operator.getText());
throw new ParseException(
"Unsupported arithmetic binary type: " + ctx.operator.getText(), ctx);
}
});
});
@ -537,7 +541,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
case "DATETIME":
return new DateTimeLiteral(value);
default:
throw new IllegalStateException("Unsupported data type : " + type);
throw new ParseException("Unsupported data type : " + type, ctx);
}
}
@ -552,7 +556,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
return new UnboundSlot(nameParts);
} else {
// todo: base is an expression, may be not a table name.
throw new IllegalStateException("Unsupported dereference expression: " + ctx.getText());
throw new ParseException("Unsupported dereference expression: " + ctx.getText(), ctx);
}
});
}
@ -614,7 +618,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
LogicalPlan right = plan(ctx.relationPrimary());
if (ctx.LATERAL() != null) {
if (!(right instanceof LogicalSubQueryAlias)) {
throw new IllegalStateException("lateral join right table should be sub-query");
throw new ParseException("lateral join right table should be sub-query", ctx);
}
}
return right;
@ -719,7 +723,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
private LogicalPlan withSort(LogicalPlan input, Optional<SortClauseContext> sortCtx) {
return input.optionalMap(sortCtx, () -> {
List<OrderKey> orderKeys = visit(sortCtx.get().sortItem(), OrderKey.class);
return new LogicalSort(orderKeys, input);
return new LogicalSort<>(orderKeys, input);
});
}
@ -732,7 +736,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
if (input instanceof LogicalSort) {
offset = Long.parseLong(offsetToken.getText());
} else {
throw new IllegalStateException("OFFSET requires an ORDER BY clause");
throw new ParseException("OFFSET requires an ORDER BY clause", limitCtx.get());
}
}
return new LogicalLimit<>(limit, offset, input);
@ -766,14 +770,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
});
}
private LogicalPlan withRelation(Optional<FromClauseContext> ctx) {
if (ctx.isPresent()) {
return visitFromClause(ctx.get());
} else {
throw new IllegalStateException("Unsupported one row relation");
}
}
/**
* Join one more [[LogicalPlan]]s to the current logical plan.
*/
@ -916,7 +912,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
}
break;
default:
throw new IllegalStateException("Unsupported predicate type: " + ctx.kind.getText());
throw new ParseException("Unsupported predicate type: " + ctx.kind.getText(), ctx);
}
return ctx.NOT() != null ? new Not(outExpression) : outExpression;
});
@ -925,14 +921,13 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
private List<NamedExpression> getNamedExpressions(NamedExpressionSeqContext namedCtx) {
return ParserUtils.withOrigin(namedCtx, () -> {
List<Expression> expressions = visit(namedCtx.namedExpression(), Expression.class);
List<NamedExpression> namedExpressions = expressions.stream().map(expression -> {
return expressions.stream().map(expression -> {
if (expression instanceof NamedExpression) {
return (NamedExpression) expression;
} else {
return new UnboundAlias(expression);
}
}).collect(ImmutableList.toImmutableList());
return namedExpressions;
});
}
@ -947,8 +942,6 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
}
public List<Expression> withInList(PredicateContext ctx) {
List<Expression> expressions = ctx.expression().stream()
.map(this::getExpression).collect(ImmutableList.toImmutableList());
return expressions;
return ctx.expression().stream().map(this::getExpression).collect(ImmutableList.toImmutableList());
}
}

View File

@ -17,92 +17,73 @@
package org.apache.doris.nereids.parser;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.exceptions.ParseException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class LimitClauseTest {
public class LimitClauseTest extends ParserTestBase {
@Test
public void testLimit() {
NereidsParser nereidsParser = new NereidsParser();
String sql = "SELECT b FROM test order by a limit 3 offset 100";
LogicalPlan logicalPlan = nereidsParser.parseSingle(sql);
Assertions.assertTrue(logicalPlan instanceof LogicalLimit);
LogicalLimit limit = (LogicalLimit) logicalPlan;
Assertions.assertEquals(3, limit.getLimit());
Assertions.assertEquals(100, limit.getOffset());
Assertions.assertEquals(1, limit.children().size());
Assertions.assertTrue(limit.child(0) instanceof LogicalSort);
parsePlan("SELECT b FROM test order by a limit 3 offset 100")
.matchesFromRoot(
logicalLimit(
logicalSort()
).when(limit -> limit.getLimit() == 3 && limit.getOffset() == 100)
);
sql = "SELECT b FROM test order by a limit 100, 3";
logicalPlan = nereidsParser.parseSingle(sql);
Assertions.assertTrue(logicalPlan instanceof LogicalLimit);
limit = (LogicalLimit) logicalPlan;
Assertions.assertEquals(3, limit.getLimit());
Assertions.assertEquals(100, limit.getOffset());
Assertions.assertEquals(1, limit.children().size());
Assertions.assertTrue(limit.child(0) instanceof LogicalSort);
parsePlan("SELECT b FROM test order by a limit 100, 3")
.matchesFromRoot(
logicalLimit(
logicalSort()
).when(limit -> limit.getLimit() == 3 && limit.getOffset() == 100)
);
sql = "SELECT b FROM test limit 3";
logicalPlan = nereidsParser.parseSingle(sql);
Assertions.assertTrue(logicalPlan instanceof LogicalLimit);
limit = (LogicalLimit) logicalPlan;
Assertions.assertEquals(3, limit.getLimit());
Assertions.assertEquals(0, limit.getOffset());
Assertions.assertEquals(1, limit.children().size());
Assertions.assertTrue(limit.child(0) instanceof LogicalProject);
sql = "SELECT b FROM test order by a limit 3";
logicalPlan = nereidsParser.parseSingle(sql);
Assertions.assertTrue(logicalPlan instanceof LogicalLimit);
limit = (LogicalLimit) logicalPlan;
Assertions.assertEquals(3, limit.getLimit());
Assertions.assertEquals(0, limit.getOffset());
Assertions.assertEquals(1, limit.children().size());
Assertions.assertTrue(limit.child(0) instanceof LogicalSort);
parsePlan("SELECT b FROM test limit 3")
.matchesFromRoot(logicalLimit().when(limit -> limit.getLimit() == 3 && limit.getOffset() == 0));
parsePlan("SELECT b FROM test order by a limit 3")
.matchesFromRoot(
logicalLimit(
logicalSort()
).when(limit -> limit.getLimit() == 3 && limit.getOffset() == 0)
);
}
@Test
public void testLimitExceptionCase() {
NereidsParser nereidsParser = new NereidsParser();
IllegalStateException exception = Assertions.assertThrows(
IllegalStateException.class,
() -> {
String sql = "SELECT b FROM test limit 3 offset 100";
nereidsParser.parseSingle(sql);
});
Assertions.assertEquals("OFFSET requires an ORDER BY clause",
exception.getMessage());
exception = Assertions.assertThrows(
IllegalStateException.class,
() -> {
String sql = "SELECT b FROM test limit 100, 3";
nereidsParser.parseSingle(sql);
});
Assertions.assertEquals("OFFSET requires an ORDER BY clause",
exception.getMessage());
parsePlan("SELECT b FROM test limit 3 offset 100")
.assertThrowsExactly(ParseException.class)
.assertMessageContains("\n"
+ "OFFSET requires an ORDER BY clause(line 1, pos19)\n"
+ "\n"
+ "== SQL ==\n"
+ "SELECT b FROM test limit 3 offset 100\n"
+ "-------------------^^^");
parsePlan("SELECT b FROM test limit 100, 3")
.assertThrowsExactly(ParseException.class)
.assertMessageContains("\n"
+ "OFFSET requires an ORDER BY clause(line 1, pos19)\n"
+ "\n"
+ "== SQL ==\n"
+ "SELECT b FROM test limit 100, 3\n"
+ "-------------------^^^");
}
@Test
public void testNoLimit() {
NereidsParser nereidsParser = new NereidsParser();
String sql = "select a from tbl order by x";
LogicalPlan root = nereidsParser.parseSingle(sql);
Assertions.assertTrue(root instanceof LogicalSort);
parsePlan("select a from tbl order by x").matchesFromRoot(logicalSort());
}
@Test
public void testNoQueryOrganization() {
NereidsParser nereidsParser = new NereidsParser();
String sql = "select a from tbl";
LogicalPlan root = nereidsParser.parseSingle(sql);
Assertions.assertTrue(root instanceof LogicalProject);
parsePlan("select a from tbl")
.matchesFromRoot(
logicalProject(
unboundRelation()
)
);
}
}

View File

@ -34,7 +34,7 @@ import org.junit.jupiter.api.Test;
import java.util.List;
public class NereidsParserTest {
public class NereidsParserTest extends ParserTestBase {
@Test
public void testParseMultiple() {
@ -60,22 +60,17 @@ public class NereidsParserTest {
@Test
public void testErrorListener() {
Exception exception = Assertions.assertThrows(ParseException.class, () -> {
String sql = "select * from t1 where a = 1 illegal_symbol";
NereidsParser nereidsParser = new NereidsParser();
nereidsParser.parseSingle(sql);
});
Assertions.assertEquals("\nextraneous input 'illegal_symbol' expecting {<EOF>, ';'}(line 1, pos29)\n",
exception.getMessage());
parsePlan("select * from t1 where a = 1 illegal_symbol")
.assertThrowsExactly(ParseException.class)
.assertMessageEquals("\nextraneous input 'illegal_symbol' expecting {<EOF>, ';'}(line 1, pos29)\n");
}
@Test
public void testPostProcessor() {
String sql = "select `AD``D` from t1 where a = 1";
NereidsParser nereidsParser = new NereidsParser();
LogicalPlan logicalPlan = nereidsParser.parseSingle(sql);
LogicalProject<Plan> logicalProject = (LogicalProject) logicalPlan;
Assertions.assertEquals("AD`D", logicalProject.getProjects().get(0).getName());
parsePlan("select `AD``D` from t1 where a = 1")
.matchesFromRoot(
logicalProject().when(p -> "AD`D".equals(p.getProjects().get(0).getName()))
);
}
@Test

View File

@ -0,0 +1,35 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser;
import org.apache.doris.nereids.util.ExpressionParseChecker;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanParseChecker;
/**
* Base class to check SQL parsing result.
*/
public abstract class ParserTestBase implements PatternMatchSupported {
public PlanParseChecker parsePlan(String sql) {
return new PlanParseChecker(sql);
}
public ExpressionParseChecker parseExpression(String sql) {
return new ExpressionParseChecker(sql);
}
}

View File

@ -60,7 +60,7 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
.analyze("select id,name,grade from student left join score on student.id = score.sid"
+ " where score.grade > 60")
.applyTopDown(new ColumnPruning())
.matches(
.matchesFromRoot(
logicalProject(
logicalFilter(
logicalProject(
@ -92,7 +92,7 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
+ "from student left join score on student.id = score.sid "
+ "where score.grade > 60")
.applyTopDown(new ColumnPruning())
.matches(
.matchesFromRoot(
logicalProject(
logicalFilter(
logicalProject(
@ -122,7 +122,7 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
PlanChecker.from(connectContext)
.analyze("select id,name from student where age > 18")
.applyTopDown(new ColumnPruning())
.matches(
.matchesFromRoot(
logicalProject(
logicalFilter(
logicalProject().when(p -> getOutputQualifiedNames(p)
@ -144,7 +144,7 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
+ "on score.cid = course.cid "
+ "where score.grade > 60")
.applyTopDown(new ColumnPruning())
.matches(
.matchesFromRoot(
logicalProject(
logicalFilter(
logicalProject(

View File

@ -17,17 +17,34 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.exceptions.ParseException;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.parser.ParserTestBase;
import org.junit.jupiter.api.Test;
public class ExpressionParserTest {
public class ExpressionParserTest extends ParserTestBase {
private static final NereidsParser PARSER = new NereidsParser();
/**
* This method is deprecated.
* <p>
* Please use utility functions `parsePlan `in {@link ParserTestBase}
* to get {@link org.apache.doris.nereids.util.PlanParseChecker}.
*/
@Deprecated
private void assertSql(String sql) {
PARSER.parseSingle(sql);
}
/**
* This method is deprecated.
* <p>
* Please use utility functions `parseExpression` in {@link ParserTestBase}
* to get {@link org.apache.doris.nereids.util.PlanParseChecker}.
*/
@Deprecated
private void assertExpr(String expr) {
Expression expression = PARSER.parseExpression(expr);
System.out.println(expression.toSql());
@ -41,12 +58,18 @@ public class ExpressionParserTest {
@Test
public void testExprBetweenPredicate() {
String sql = "c BETWEEN a AND b";
assertExpr(sql);
parseExpression("c BETWEEN a AND b")
.assertEquals(
new Between(
new UnboundSlot("c"),
new UnboundSlot("a"),
new UnboundSlot("b")
)
);
}
@Test
public void testInPredicate() throws Exception {
public void testInPredicate() {
String in = "select * from test1 where d1 in (1, 2, 3)";
assertSql(in);
@ -55,7 +78,7 @@ public class ExpressionParserTest {
}
@Test
public void testSqlAnd() throws Exception {
public void testSqlAnd() {
String sql = "select * from test1 where a > 1 and b > 1";
assertSql(sql);
}
@ -100,6 +123,11 @@ public class ExpressionParserTest {
String subtract = "3 - 2";
assertExpr(subtract);
parseExpression("3 += 2")
.assertThrowsExactly(ParseException.class)
.assertMessageContains("extraneous input '=' expecting {'(");
}
@Test

View File

@ -114,7 +114,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported
.analyze("SELECT * FROM V1")
.applyTopDown(new EliminateAliasNode())
.applyTopDown(new MergeConsecutiveProjects())
.matches(
.matchesFromRoot(
logicalProject(
logicalOlapScan()
)
@ -127,7 +127,7 @@ public class ViewTest extends TestWithFeService implements PatternMatchSupported
.analyze("SELECT * FROM (SELECT * FROM V1 JOIN V2 ON V1.ID1 = V2.ID2) X JOIN (SELECT * FROM V1 JOIN V3 ON V1.ID1 = V3.ID2) Y ON X.ID1 = Y.ID3")
.applyTopDown(new EliminateAliasNode())
.applyTopDown(new MergeConsecutiveProjects())
.matches(
.matchesFromRoot(
logicalProject(
logicalJoin(
logicalProject(

View File

@ -103,7 +103,7 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat
PlanChecker.from(connectContext)
.analyze(testSql.get(0))
.applyTopDown(new EliminateAliasNode())
.matches(
.matchesFromRoot(
logicalProject(
logicalProject(
logicalOlapScan().when(o -> true)
@ -123,7 +123,7 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat
PlanChecker.from(connectContext)
.analyze(testSql.get(1))
.applyTopDown(new EliminateAliasNode())
.matches(
.matchesFromRoot(
logicalProject(
logicalJoin(
logicalOlapScan(),
@ -154,7 +154,7 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat
PlanChecker.from(connectContext)
.analyze(testSql.get(5))
.applyTopDown(new EliminateAliasNode())
.matches(
.matchesFromRoot(
logicalProject(
logicalJoin(
logicalOlapScan(),

View File

@ -0,0 +1,48 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.junit.jupiter.api.Assertions;
import java.util.function.Function;
/**
* Helper to check exception message.
*/
public class ExceptionChecker {
private final Throwable exception;
public ExceptionChecker(Throwable exception) {
this.exception = exception;
}
public ExceptionChecker assertMessageEquals(String message) {
Assertions.assertEquals(message, exception.getMessage());
return this;
}
public ExceptionChecker assertMessageContains(String context) {
Assertions.assertTrue(exception.getMessage().contains(context));
return this;
}
public ExceptionChecker assertWith(Function<Throwable, Boolean> asserter) {
Assertions.assertTrue(asserter.apply(exception));
return this;
}
}

View File

@ -0,0 +1,46 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import org.junit.jupiter.api.Assertions;
public class ExpressionParseChecker extends ParseChecker {
private final Supplier<Expression> parsedSupplier;
public ExpressionParseChecker(String sql) {
super(sql);
this.parsedSupplier = Suppliers.memoize(() -> PARSER.parseExpression(sql));
}
public ExpressionParseChecker assertEquals(Expression expected) {
Assertions.assertEquals(expected, parsedSupplier.get());
return this;
}
public <T extends Throwable> ExceptionChecker assertThrows(Class<T> expectedType) {
return new ExceptionChecker(Assertions.assertThrows(expectedType, parsedSupplier::get));
}
public <T extends Throwable> ExceptionChecker assertThrowsExactly(Class<T> expectedType) {
return new ExceptionChecker(Assertions.assertThrowsExactly(expectedType, parsedSupplier::get));
}
}

View File

@ -0,0 +1,43 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.pattern.GroupExpressionMatching;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.trees.plans.Plan;
public class GroupMatchingUtils {
public static boolean topDownFindMatching(Group group, Pattern<? extends Plan> pattern) {
GroupExpression logicalExpr = group.getLogicalExpression();
GroupExpressionMatching matchingResult = new GroupExpressionMatching(pattern, logicalExpr);
if (matchingResult.iterator().hasNext()) {
return true;
} else {
for (Group childGroup : logicalExpr.children()) {
boolean checkResult = topDownFindMatching(childGroup, pattern);
if (checkResult) {
return true;
}
}
}
return false;
}
}

View File

@ -0,0 +1,29 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.parser.NereidsParser;
public abstract class ParseChecker {
protected static final NereidsParser PARSER = new NereidsParser();
protected final String sql;
public ParseChecker(String sql) {
this.sql = sql;
}
}

View File

@ -25,8 +25,11 @@ import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Supplier;
import org.junit.jupiter.api.Assertions;
import java.util.function.Consumer;
/**
* Utility to apply rules to plan and check output plan matches the expected pattern.
*/
@ -34,6 +37,7 @@ public class PlanChecker {
private ConnectContext connectContext;
private CascadesContext cascadesContext;
private Plan parsedPlan;
public PlanChecker(ConnectContext connectContext) {
this.connectContext = connectContext;
@ -44,6 +48,18 @@ public class PlanChecker {
this.cascadesContext = cascadesContext;
}
public PlanChecker checkParse(String sql, Consumer<PlanParseChecker> consumer) {
PlanParseChecker checker = new PlanParseChecker(sql);
consumer.accept(checker);
parsedPlan = checker.parsedSupplier.get();
return this;
}
public PlanChecker analyze() {
MemoTestUtils.createCascadesContext(connectContext, parsedPlan);
return this;
}
public PlanChecker analyze(String sql) {
this.cascadesContext = MemoTestUtils.createCascadesContext(connectContext, sql);
this.cascadesContext.newAnalyzer().analyze();
@ -66,12 +82,22 @@ public class PlanChecker {
return this;
}
public void matchesFromRoot(PatternDescriptor<? extends Plan> patternDesc) {
Memo memo = cascadesContext.getMemo();
assertMatches(memo, () -> new GroupExpressionMatching(patternDesc.pattern,
memo.getRoot().getLogicalExpression()).iterator().hasNext());
}
public void matches(PatternDescriptor<? extends Plan> patternDesc) {
Memo memo = cascadesContext.getMemo();
GroupExpressionMatching matchResult = new GroupExpressionMatching(patternDesc.pattern,
memo.getRoot().getLogicalExpression());
Assertions.assertTrue(matchResult.iterator().hasNext(), () ->
"pattern not match, plan :\n" + memo.getRoot().getLogicalExpression().getPlan().treeString() + "\n"
assertMatches(memo, () -> GroupMatchingUtils.topDownFindMatching(memo.getRoot(), patternDesc.pattern));
}
private void assertMatches(Memo memo, Supplier<Boolean> asserter) {
Assertions.assertTrue(asserter.get(),
() -> "pattern not match, plan :\n"
+ memo.getRoot().getLogicalExpression().getPlan().treeString()
+ "\n"
);
}

View File

@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.pattern.GroupExpressionMatching;
import org.apache.doris.nereids.pattern.PatternDescriptor;
import org.apache.doris.nereids.trees.plans.Plan;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import org.junit.jupiter.api.Assertions;
public class PlanParseChecker extends ParseChecker {
final Supplier<Plan> parsedSupplier;
public PlanParseChecker(String sql) {
super(sql);
this.parsedSupplier = Suppliers.memoize(() -> PARSER.parseSingle(sql));
}
public PlanParseChecker matches(PatternDescriptor<? extends Plan> patternDesc) {
assertMatches(() -> GroupMatchingUtils.topDownFindMatching(
new Memo(parsedSupplier.get()).getRoot(), patternDesc.pattern));
return this;
}
public PlanParseChecker matchesFromRoot(PatternDescriptor<? extends Plan> patternDesc) {
assertMatches(() -> new GroupExpressionMatching(patternDesc.pattern,
new Memo(parsedSupplier.get()).getRoot().getLogicalExpression())
.iterator().hasNext());
return this;
}
public <T extends Throwable> ExceptionChecker assertThrows(Class<T> expectedType) {
return new ExceptionChecker(Assertions.assertThrows(expectedType, parsedSupplier::get));
}
public <T extends Throwable> ExceptionChecker assertThrowsExactly(Class<T> expectedType) {
return new ExceptionChecker(Assertions.assertThrowsExactly(expectedType, parsedSupplier::get));
}
private void assertMatches(Supplier<Boolean> assertResultSupplier) {
Assertions.assertTrue(assertResultSupplier.get(),
() -> "pattern not match,\ninput SQL:\n" + sql
+ "\n, parsed plan :\n" + parsedSupplier.get().treeString() + "\n"
);
}
}