[test](Nereids) runtime filter unit cases not rely on NereidPlanner to generate PhysicalPlan anymore (#12740)

This PR:
1. add rewrite and implement method to PlanChecker
2. improve unit tests of runtime filter
This commit is contained in:
mch_ucchi
2022-09-19 19:53:55 +08:00
committed by GitHub
parent 1339eef33c
commit 94d73abf2a
4 changed files with 109 additions and 31 deletions

View File

@ -75,6 +75,7 @@ public class RuntimeFilterTranslator {
SlotRef src = ctx.findSlotRef(filter.getSrcExpr().getExprId());
SlotRef target = context.getExprIdToOlapScanNodeSlotRef().get(filter.getTargetExpr().getExprId());
if (target == null) {
context.setTargetNullCount();
return;
}
org.apache.doris.planner.RuntimeFilter origFilter

View File

@ -71,6 +71,8 @@ public class RuntimeFilterContext {
private final FilterSizeLimits limits;
private int targetNullCount = 0;
public RuntimeFilterContext(SessionVariable sessionVariable) {
this.sessionVariable = sessionVariable;
this.limits = new FilterSizeLimits(sessionVariable);
@ -178,4 +180,13 @@ public class RuntimeFilterContext {
} while (expr != null);
return builder.build();
}
public void setTargetNullCount() {
targetNullCount++;
}
@VisibleForTesting
public int getTargetNullCount() {
return targetNullCount;
}
}

View File

@ -18,17 +18,15 @@
package org.apache.doris.nereids.postprocess;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.datasets.ssb.SSBTestBase;
import org.apache.doris.nereids.datasets.ssb.SSBUtils;
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.processor.post.PlanPostProcessors;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.nereids.util.PlanChecker;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@ -49,7 +47,8 @@ public class RuntimeFilterTest extends SSBTestBase {
public void testGenerateRuntimeFilter() throws AnalysisException {
String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 1);
Assertions.assertTrue(filters.size() == 1
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
}
@Test
@ -64,14 +63,17 @@ public class RuntimeFilterTest extends SSBTestBase {
String sql
= "SELECT * FROM supplier JOIN customer on c_name = s_name and s_city = c_city and s_nation = c_nation";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 3);
Assertions.assertTrue(filters.size() == 3
&& checkRuntimeFilterExprs(filters, "c_name", "s_name", "c_city", "s_city", "c_nation", "s_nation"));
}
@Test
public void testNestedJoinGenerateRuntimeFilter() throws AnalysisException {
String sql = SSBUtils.Q4_1;
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 4);
Assertions.assertTrue(filters.size() == 4
&& checkRuntimeFilterExprs(filters, "p_partkey", "lo_partkey", "s_suppkey", "lo_suppkey",
"c_custkey", "lo_custkey", "lo_orderdate", "d_datekey"));
}
@Test
@ -81,7 +83,8 @@ public class RuntimeFilterTest extends SSBTestBase {
+ " left outer join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 3);
Assertions.assertTrue(filters.size() == 2
&& checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey"));
}
@Test
@ -91,7 +94,8 @@ public class RuntimeFilterTest extends SSBTestBase {
+ " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 3);
Assertions.assertTrue(filters.size() == 1
&& checkRuntimeFilterExprs(filters, "s_suppkey", "c_custkey"));
}
@Test
@ -102,7 +106,9 @@ public class RuntimeFilterTest extends SSBTestBase {
+ " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 3);
Assertions.assertTrue(filters.size() == 3
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate",
"s_suppkey", "c_custkey"));
}
@Test
@ -113,21 +119,24 @@ public class RuntimeFilterTest extends SSBTestBase {
+ " inner join (select sum(c_custkey) c_custkey from customer inner join supplier on c_custkey = s_suppkey group by s_suppkey) b"
+ " on b.c_custkey = a.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 2);
Assertions.assertTrue(filters.size() == 2
&& checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey"));
}
@Test
public void testCrossJoin() throws AnalysisException {
String sql = "select c_custkey, lo_custkey from lineorder, customer where lo_custkey = c_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 1);
Assertions.assertTrue(filters.size() == 1
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
}
@Test
public void testSubQueryAlias() throws AnalysisException {
String sql = "select c_custkey, lo_custkey from lineorder l, customer c where c.c_custkey = l.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 1);
Assertions.assertTrue(filters.size() == 1
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
}
@Test
@ -156,8 +165,9 @@ public class RuntimeFilterTest extends SSBTestBase {
+ " on t1.p_partkey = t2.lo_partkey\n"
+ " order by t1.lo_custkey, t1.p_partkey, t2.s_suppkey, t2.c_custkey, t2.lo_orderkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 4);
Assertions.assertTrue(filters.size() == 4
&& checkRuntimeFilterExprs(filters, "lo_partkey", "p_partkey", "lo_partkey", "p_partkey",
"c_region", "s_region", "lo_custkey", "c_custkey"));
}
@Test
@ -168,10 +178,11 @@ public class RuntimeFilterTest extends SSBTestBase {
+ " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder"
+ " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 5);
Assertions.assertTrue(filters.size() == 5
&& checkRuntimeFilterExprs(filters, "lo_custkey", "c_custkey", "c_custkey", "lo_custkey",
"d_datekey", "lo_orderdate", "s_suppkey", "c_custkey", "lo_custkey", "c_custkey"));
}
/*
@Test
public void testPushDownThroughUnsupportedJoinType() throws AnalysisException {
String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates"
@ -180,23 +191,33 @@ public class RuntimeFilterTest extends SSBTestBase {
+ " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder"
+ " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey";
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
Assertions.assertTrue(filters.size() == 5);
Assertions.assertTrue(filters.size() == 3
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate",
"lo_custkey", "c_custkey"));
}
*/
private Optional<List<RuntimeFilter>> getRuntimeFilters(String sql) throws AnalysisException {
NereidsPlanner planner = new NereidsPlanner(createStatementCtx(sql));
PhysicalPlan plan = planner.plan(new NereidsParser().parseSingle(sql), PhysicalProperties.ANY);
private Optional<List<RuntimeFilter>> getRuntimeFilters(String sql) {
PlanChecker checker = PlanChecker.from(connectContext).analyze(sql)
.rewrite()
.implement();
PhysicalPlan plan = checker.getPhysicalPlan();
new PlanPostProcessors(checker.getCascadesContext()).process(plan);
System.out.println(plan.treeString());
PlanTranslatorContext context = new PlanTranslatorContext(planner.getCascadesContext());
PlanFragment root = new PhysicalPlanTranslator().translatePlan(plan, context);
System.out.println(root.getFragmentId());
if (context.getRuntimeTranslator().isPresent()) {
RuntimeFilterContext ctx = planner.getCascadesContext().getRuntimeFilterContext();
Assertions.assertEquals(ctx.getNereidsRuntimeFilter().size(), ctx.getLegacyFilters().size());
return Optional.of(ctx.getNereidsRuntimeFilter());
new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext(checker.getCascadesContext()));
RuntimeFilterContext context = checker.getCascadesContext().getRuntimeFilterContext();
List<RuntimeFilter> filters = context.getNereidsRuntimeFilter();
Assertions.assertEquals(filters.size(), context.getLegacyFilters().size() + context.getTargetNullCount());
return Optional.of(filters);
}
private boolean checkRuntimeFilterExprs(List<RuntimeFilter> filters, String... colNames) {
int idx = 0;
for (RuntimeFilter filter : filters) {
if (!checkRuntimeFilterExpr(filter, colNames[idx++], colNames[idx++])) {
return false;
}
}
return Optional.empty();
return true;
}
private boolean checkRuntimeFilterExpr(RuntimeFilter filter, String srcColName, String targetColName) {

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
@ -32,10 +33,12 @@ import org.apache.doris.nereids.pattern.PatternDescriptor;
import org.apache.doris.nereids.pattern.PatternMatcher;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
@ -43,7 +46,9 @@ import com.google.common.base.Supplier;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;
/**
* Utility to apply rules to plan and check output plan matches the expected pattern.
@ -54,6 +59,8 @@ public class PlanChecker {
private Plan parsedPlan;
private PhysicalPlan physicalPlan;
public PlanChecker(ConnectContext connectContext) {
this.connectContext = connectContext;
}
@ -88,7 +95,11 @@ public class PlanChecker {
return this;
}
public PlanChecker applyTopDown(RuleFactory rule) {
public PlanChecker applyTopDown(RuleFactory ruleFactory) {
return applyTopDown(ruleFactory.buildRules());
}
public PlanChecker applyTopDown(List<Rule> rule) {
cascadesContext.topDownRewrite(rule);
MemoValidator.validate(cascadesContext.getMemo());
return this;
@ -132,6 +143,40 @@ public class PlanChecker {
return this;
}
public PlanChecker rewrite() {
new NereidsRewriteJobExecutor(cascadesContext).execute();
return this;
}
public PlanChecker implement() {
Plan plan = transformToPhysicalPlan(cascadesContext.getMemo().getRoot());
Assertions.assertTrue(plan instanceof PhysicalPlan);
physicalPlan = ((PhysicalPlan) plan);
return this;
}
public PhysicalPlan getPhysicalPlan() {
return physicalPlan;
}
private Plan transformToPhysicalPlan(Group group) {
PhysicalPlan current = null;
loop:
for (Rule rule : RuleSet.IMPLEMENTATION_RULES) {
GroupExpressionMatching matching = new GroupExpressionMatching(rule.getPattern(), group.getLogicalExpression());
for (Plan plan : matching) {
Plan after = rule.transform(plan, cascadesContext).get(0);
if (after instanceof PhysicalPlan) {
current = (PhysicalPlan) after;
break loop;
}
}
}
Assertions.assertNotNull(current);
return current.withChildren(group.getLogicalExpression().children()
.stream().map(this::transformToPhysicalPlan).collect(Collectors.toList()));
}
public PlanChecker transform(PatternMatcher patternMatcher) {
return transform(cascadesContext.getMemo().getRoot(), patternMatcher);
}