[test](Nereids) runtime filter unit cases not rely on NereidPlanner to generate PhysicalPlan anymore (#12740)
This PR: 1. add rewrite and implement method to PlanChecker 2. improve unit tests of runtime filter
This commit is contained in:
@ -75,6 +75,7 @@ public class RuntimeFilterTranslator {
|
||||
SlotRef src = ctx.findSlotRef(filter.getSrcExpr().getExprId());
|
||||
SlotRef target = context.getExprIdToOlapScanNodeSlotRef().get(filter.getTargetExpr().getExprId());
|
||||
if (target == null) {
|
||||
context.setTargetNullCount();
|
||||
return;
|
||||
}
|
||||
org.apache.doris.planner.RuntimeFilter origFilter
|
||||
|
||||
@ -71,6 +71,8 @@ public class RuntimeFilterContext {
|
||||
|
||||
private final FilterSizeLimits limits;
|
||||
|
||||
private int targetNullCount = 0;
|
||||
|
||||
public RuntimeFilterContext(SessionVariable sessionVariable) {
|
||||
this.sessionVariable = sessionVariable;
|
||||
this.limits = new FilterSizeLimits(sessionVariable);
|
||||
@ -178,4 +180,13 @@ public class RuntimeFilterContext {
|
||||
} while (expr != null);
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public void setTargetNullCount() {
|
||||
targetNullCount++;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public int getTargetNullCount() {
|
||||
return targetNullCount;
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,17 +18,15 @@
|
||||
package org.apache.doris.nereids.postprocess;
|
||||
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.datasets.ssb.SSBTestBase;
|
||||
import org.apache.doris.nereids.datasets.ssb.SSBUtils;
|
||||
import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator;
|
||||
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.processor.post.PlanPostProcessors;
|
||||
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter;
|
||||
import org.apache.doris.planner.PlanFragment;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@ -49,7 +47,8 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
public void testGenerateRuntimeFilter() throws AnalysisException {
|
||||
String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = lo_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 1);
|
||||
Assertions.assertTrue(filters.size() == 1
|
||||
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -64,14 +63,17 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
String sql
|
||||
= "SELECT * FROM supplier JOIN customer on c_name = s_name and s_city = c_city and s_nation = c_nation";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 3);
|
||||
Assertions.assertTrue(filters.size() == 3
|
||||
&& checkRuntimeFilterExprs(filters, "c_name", "s_name", "c_city", "s_city", "c_nation", "s_nation"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNestedJoinGenerateRuntimeFilter() throws AnalysisException {
|
||||
String sql = SSBUtils.Q4_1;
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 4);
|
||||
Assertions.assertTrue(filters.size() == 4
|
||||
&& checkRuntimeFilterExprs(filters, "p_partkey", "lo_partkey", "s_suppkey", "lo_suppkey",
|
||||
"c_custkey", "lo_custkey", "lo_orderdate", "d_datekey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -81,7 +83,8 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
+ " left outer join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
|
||||
+ " on b.c_custkey = a.lo_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 3);
|
||||
Assertions.assertTrue(filters.size() == 2
|
||||
&& checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -91,7 +94,8 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
+ " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
|
||||
+ " on b.c_custkey = a.lo_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 3);
|
||||
Assertions.assertTrue(filters.size() == 1
|
||||
&& checkRuntimeFilterExprs(filters, "s_suppkey", "c_custkey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -102,7 +106,9 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
+ " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b"
|
||||
+ " on b.c_custkey = a.lo_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 3);
|
||||
Assertions.assertTrue(filters.size() == 3
|
||||
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate",
|
||||
"s_suppkey", "c_custkey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -113,21 +119,24 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
+ " inner join (select sum(c_custkey) c_custkey from customer inner join supplier on c_custkey = s_suppkey group by s_suppkey) b"
|
||||
+ " on b.c_custkey = a.lo_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 2);
|
||||
Assertions.assertTrue(filters.size() == 2
|
||||
&& checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCrossJoin() throws AnalysisException {
|
||||
String sql = "select c_custkey, lo_custkey from lineorder, customer where lo_custkey = c_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 1);
|
||||
Assertions.assertTrue(filters.size() == 1
|
||||
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSubQueryAlias() throws AnalysisException {
|
||||
String sql = "select c_custkey, lo_custkey from lineorder l, customer c where c.c_custkey = l.lo_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 1);
|
||||
Assertions.assertTrue(filters.size() == 1
|
||||
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -156,8 +165,9 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
+ " on t1.p_partkey = t2.lo_partkey\n"
|
||||
+ " order by t1.lo_custkey, t1.p_partkey, t2.s_suppkey, t2.c_custkey, t2.lo_orderkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 4);
|
||||
|
||||
Assertions.assertTrue(filters.size() == 4
|
||||
&& checkRuntimeFilterExprs(filters, "lo_partkey", "p_partkey", "lo_partkey", "p_partkey",
|
||||
"c_region", "s_region", "lo_custkey", "c_custkey"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -168,10 +178,11 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
+ " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder"
|
||||
+ " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 5);
|
||||
Assertions.assertTrue(filters.size() == 5
|
||||
&& checkRuntimeFilterExprs(filters, "lo_custkey", "c_custkey", "c_custkey", "lo_custkey",
|
||||
"d_datekey", "lo_orderdate", "s_suppkey", "c_custkey", "lo_custkey", "c_custkey"));
|
||||
}
|
||||
|
||||
/*
|
||||
@Test
|
||||
public void testPushDownThroughUnsupportedJoinType() throws AnalysisException {
|
||||
String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates"
|
||||
@ -180,23 +191,33 @@ public class RuntimeFilterTest extends SSBTestBase {
|
||||
+ " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder"
|
||||
+ " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey";
|
||||
List<RuntimeFilter> filters = getRuntimeFilters(sql).get();
|
||||
Assertions.assertTrue(filters.size() == 5);
|
||||
Assertions.assertTrue(filters.size() == 3
|
||||
&& checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate",
|
||||
"lo_custkey", "c_custkey"));
|
||||
}
|
||||
*/
|
||||
|
||||
private Optional<List<RuntimeFilter>> getRuntimeFilters(String sql) throws AnalysisException {
|
||||
NereidsPlanner planner = new NereidsPlanner(createStatementCtx(sql));
|
||||
PhysicalPlan plan = planner.plan(new NereidsParser().parseSingle(sql), PhysicalProperties.ANY);
|
||||
private Optional<List<RuntimeFilter>> getRuntimeFilters(String sql) {
|
||||
PlanChecker checker = PlanChecker.from(connectContext).analyze(sql)
|
||||
.rewrite()
|
||||
.implement();
|
||||
PhysicalPlan plan = checker.getPhysicalPlan();
|
||||
new PlanPostProcessors(checker.getCascadesContext()).process(plan);
|
||||
System.out.println(plan.treeString());
|
||||
PlanTranslatorContext context = new PlanTranslatorContext(planner.getCascadesContext());
|
||||
PlanFragment root = new PhysicalPlanTranslator().translatePlan(plan, context);
|
||||
System.out.println(root.getFragmentId());
|
||||
if (context.getRuntimeTranslator().isPresent()) {
|
||||
RuntimeFilterContext ctx = planner.getCascadesContext().getRuntimeFilterContext();
|
||||
Assertions.assertEquals(ctx.getNereidsRuntimeFilter().size(), ctx.getLegacyFilters().size());
|
||||
return Optional.of(ctx.getNereidsRuntimeFilter());
|
||||
new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext(checker.getCascadesContext()));
|
||||
RuntimeFilterContext context = checker.getCascadesContext().getRuntimeFilterContext();
|
||||
List<RuntimeFilter> filters = context.getNereidsRuntimeFilter();
|
||||
Assertions.assertEquals(filters.size(), context.getLegacyFilters().size() + context.getTargetNullCount());
|
||||
return Optional.of(filters);
|
||||
}
|
||||
|
||||
private boolean checkRuntimeFilterExprs(List<RuntimeFilter> filters, String... colNames) {
|
||||
int idx = 0;
|
||||
for (RuntimeFilter filter : filters) {
|
||||
if (!checkRuntimeFilterExpr(filter, colNames[idx++], colNames[idx++])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return Optional.empty();
|
||||
return true;
|
||||
}
|
||||
|
||||
private boolean checkRuntimeFilterExpr(RuntimeFilter filter, String srcColName, String targetColName) {
|
||||
|
||||
@ -22,6 +22,7 @@ import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.memo.Memo;
|
||||
@ -32,10 +33,12 @@ import org.apache.doris.nereids.pattern.PatternDescriptor;
|
||||
import org.apache.doris.nereids.pattern.PatternMatcher;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleFactory;
|
||||
import org.apache.doris.nereids.rules.RuleSet;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.OriginStatement;
|
||||
|
||||
@ -43,7 +46,9 @@ import com.google.common.base.Supplier;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Utility to apply rules to plan and check output plan matches the expected pattern.
|
||||
@ -54,6 +59,8 @@ public class PlanChecker {
|
||||
|
||||
private Plan parsedPlan;
|
||||
|
||||
private PhysicalPlan physicalPlan;
|
||||
|
||||
public PlanChecker(ConnectContext connectContext) {
|
||||
this.connectContext = connectContext;
|
||||
}
|
||||
@ -88,7 +95,11 @@ public class PlanChecker {
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker applyTopDown(RuleFactory rule) {
|
||||
public PlanChecker applyTopDown(RuleFactory ruleFactory) {
|
||||
return applyTopDown(ruleFactory.buildRules());
|
||||
}
|
||||
|
||||
public PlanChecker applyTopDown(List<Rule> rule) {
|
||||
cascadesContext.topDownRewrite(rule);
|
||||
MemoValidator.validate(cascadesContext.getMemo());
|
||||
return this;
|
||||
@ -132,6 +143,40 @@ public class PlanChecker {
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker rewrite() {
|
||||
new NereidsRewriteJobExecutor(cascadesContext).execute();
|
||||
return this;
|
||||
}
|
||||
|
||||
public PlanChecker implement() {
|
||||
Plan plan = transformToPhysicalPlan(cascadesContext.getMemo().getRoot());
|
||||
Assertions.assertTrue(plan instanceof PhysicalPlan);
|
||||
physicalPlan = ((PhysicalPlan) plan);
|
||||
return this;
|
||||
}
|
||||
|
||||
public PhysicalPlan getPhysicalPlan() {
|
||||
return physicalPlan;
|
||||
}
|
||||
|
||||
private Plan transformToPhysicalPlan(Group group) {
|
||||
PhysicalPlan current = null;
|
||||
loop:
|
||||
for (Rule rule : RuleSet.IMPLEMENTATION_RULES) {
|
||||
GroupExpressionMatching matching = new GroupExpressionMatching(rule.getPattern(), group.getLogicalExpression());
|
||||
for (Plan plan : matching) {
|
||||
Plan after = rule.transform(plan, cascadesContext).get(0);
|
||||
if (after instanceof PhysicalPlan) {
|
||||
current = (PhysicalPlan) after;
|
||||
break loop;
|
||||
}
|
||||
}
|
||||
}
|
||||
Assertions.assertNotNull(current);
|
||||
return current.withChildren(group.getLogicalExpression().children()
|
||||
.stream().map(this::transformToPhysicalPlan).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
public PlanChecker transform(PatternMatcher patternMatcher) {
|
||||
return transform(cascadesContext.getMemo().getRoot(), patternMatcher);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user