diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/RuntimeFilterTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/RuntimeFilterTranslator.java index 678f574b65..1dc7a0a2c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/RuntimeFilterTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/RuntimeFilterTranslator.java @@ -75,6 +75,7 @@ public class RuntimeFilterTranslator { SlotRef src = ctx.findSlotRef(filter.getSrcExpr().getExprId()); SlotRef target = context.getExprIdToOlapScanNodeSlotRef().get(filter.getTargetExpr().getExprId()); if (target == null) { + context.setTargetNullCount(); return; } org.apache.doris.planner.RuntimeFilter origFilter diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java index 103e5c8c30..5565dcf928 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterContext.java @@ -71,6 +71,8 @@ public class RuntimeFilterContext { private final FilterSizeLimits limits; + private int targetNullCount = 0; + public RuntimeFilterContext(SessionVariable sessionVariable) { this.sessionVariable = sessionVariable; this.limits = new FilterSizeLimits(sessionVariable); @@ -178,4 +180,13 @@ public class RuntimeFilterContext { } while (expr != null); return builder.build(); } + + public void setTargetNullCount() { + targetNullCount++; + } + + @VisibleForTesting + public int getTargetNullCount() { + return targetNullCount; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java index 31fabf3f63..ce2674a854 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java @@ -18,17 +18,15 @@ package org.apache.doris.nereids.postprocess; import org.apache.doris.common.AnalysisException; -import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.datasets.ssb.SSBTestBase; import org.apache.doris.nereids.datasets.ssb.SSBUtils; import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator; import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; -import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.processor.post.PlanPostProcessors; import org.apache.doris.nereids.processor.post.RuntimeFilterContext; -import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter; -import org.apache.doris.planner.PlanFragment; +import org.apache.doris.nereids.util.PlanChecker; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -49,7 +47,8 @@ public class RuntimeFilterTest extends SSBTestBase { public void testGenerateRuntimeFilter() throws AnalysisException { String sql = "SELECT * FROM lineorder JOIN customer on c_custkey = lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 1); + Assertions.assertTrue(filters.size() == 1 + && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey")); } @Test @@ -64,14 +63,17 @@ public class RuntimeFilterTest extends SSBTestBase { String sql = "SELECT * FROM supplier JOIN customer on c_name = s_name and s_city = c_city and s_nation = c_nation"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 3); + Assertions.assertTrue(filters.size() == 3 + && checkRuntimeFilterExprs(filters, "c_name", "s_name", "c_city", "s_city", "c_nation", "s_nation")); } @Test public void testNestedJoinGenerateRuntimeFilter() throws AnalysisException { String sql = SSBUtils.Q4_1; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 4); + Assertions.assertTrue(filters.size() == 4 + && checkRuntimeFilterExprs(filters, "p_partkey", "lo_partkey", "s_suppkey", "lo_suppkey", + "c_custkey", "lo_custkey", "lo_orderdate", "d_datekey")); } @Test @@ -81,7 +83,8 @@ public class RuntimeFilterTest extends SSBTestBase { + " left outer join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b" + " on b.c_custkey = a.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 3); + Assertions.assertTrue(filters.size() == 2 + && checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey")); } @Test @@ -91,7 +94,8 @@ public class RuntimeFilterTest extends SSBTestBase { + " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b" + " on b.c_custkey = a.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 3); + Assertions.assertTrue(filters.size() == 1 + && checkRuntimeFilterExprs(filters, "s_suppkey", "c_custkey")); } @Test @@ -102,7 +106,9 @@ public class RuntimeFilterTest extends SSBTestBase { + " inner join (select c_custkey from customer inner join supplier on c_custkey = s_suppkey) b" + " on b.c_custkey = a.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 3); + Assertions.assertTrue(filters.size() == 3 + && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate", + "s_suppkey", "c_custkey")); } @Test @@ -113,21 +119,24 @@ public class RuntimeFilterTest extends SSBTestBase { + " inner join (select sum(c_custkey) c_custkey from customer inner join supplier on c_custkey = s_suppkey group by s_suppkey) b" + " on b.c_custkey = a.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 2); + Assertions.assertTrue(filters.size() == 2 + && checkRuntimeFilterExprs(filters, "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey")); } @Test public void testCrossJoin() throws AnalysisException { String sql = "select c_custkey, lo_custkey from lineorder, customer where lo_custkey = c_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 1); + Assertions.assertTrue(filters.size() == 1 + && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey")); } @Test public void testSubQueryAlias() throws AnalysisException { String sql = "select c_custkey, lo_custkey from lineorder l, customer c where c.c_custkey = l.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 1); + Assertions.assertTrue(filters.size() == 1 + && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey")); } @Test @@ -156,8 +165,9 @@ public class RuntimeFilterTest extends SSBTestBase { + " on t1.p_partkey = t2.lo_partkey\n" + " order by t1.lo_custkey, t1.p_partkey, t2.s_suppkey, t2.c_custkey, t2.lo_orderkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 4); - + Assertions.assertTrue(filters.size() == 4 + && checkRuntimeFilterExprs(filters, "lo_partkey", "p_partkey", "lo_partkey", "p_partkey", + "c_region", "s_region", "lo_custkey", "c_custkey")); } @Test @@ -168,10 +178,11 @@ public class RuntimeFilterTest extends SSBTestBase { + " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder" + " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 5); + Assertions.assertTrue(filters.size() == 5 + && checkRuntimeFilterExprs(filters, "lo_custkey", "c_custkey", "c_custkey", "lo_custkey", + "d_datekey", "lo_orderdate", "s_suppkey", "c_custkey", "lo_custkey", "c_custkey")); } - /* @Test public void testPushDownThroughUnsupportedJoinType() throws AnalysisException { String sql = "select c_custkey from (select c_custkey from (select lo_custkey from lineorder inner join dates" @@ -180,23 +191,33 @@ public class RuntimeFilterTest extends SSBTestBase { + " on b.c_custkey = a.lo_custkey) c inner join (select lo_custkey from customer inner join lineorder" + " on c_custkey = lo_custkey) d on c.c_custkey = d.lo_custkey"; List filters = getRuntimeFilters(sql).get(); - Assertions.assertTrue(filters.size() == 5); + Assertions.assertTrue(filters.size() == 3 + && checkRuntimeFilterExprs(filters, "c_custkey", "lo_custkey", "d_datekey", "lo_orderdate", + "lo_custkey", "c_custkey")); } - */ - private Optional> getRuntimeFilters(String sql) throws AnalysisException { - NereidsPlanner planner = new NereidsPlanner(createStatementCtx(sql)); - PhysicalPlan plan = planner.plan(new NereidsParser().parseSingle(sql), PhysicalProperties.ANY); + private Optional> getRuntimeFilters(String sql) { + PlanChecker checker = PlanChecker.from(connectContext).analyze(sql) + .rewrite() + .implement(); + PhysicalPlan plan = checker.getPhysicalPlan(); + new PlanPostProcessors(checker.getCascadesContext()).process(plan); System.out.println(plan.treeString()); - PlanTranslatorContext context = new PlanTranslatorContext(planner.getCascadesContext()); - PlanFragment root = new PhysicalPlanTranslator().translatePlan(plan, context); - System.out.println(root.getFragmentId()); - if (context.getRuntimeTranslator().isPresent()) { - RuntimeFilterContext ctx = planner.getCascadesContext().getRuntimeFilterContext(); - Assertions.assertEquals(ctx.getNereidsRuntimeFilter().size(), ctx.getLegacyFilters().size()); - return Optional.of(ctx.getNereidsRuntimeFilter()); + new PhysicalPlanTranslator().translatePlan(plan, new PlanTranslatorContext(checker.getCascadesContext())); + RuntimeFilterContext context = checker.getCascadesContext().getRuntimeFilterContext(); + List filters = context.getNereidsRuntimeFilter(); + Assertions.assertEquals(filters.size(), context.getLegacyFilters().size() + context.getTargetNullCount()); + return Optional.of(filters); + } + + private boolean checkRuntimeFilterExprs(List filters, String... colNames) { + int idx = 0; + for (RuntimeFilter filter : filters) { + if (!checkRuntimeFilterExpr(filter, colNames[idx++], colNames[idx++])) { + return false; + } } - return Optional.empty(); + return true; } private boolean checkRuntimeFilterExpr(RuntimeFilter filter, String srcColName, String targetColName) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index 769eb34020..ab4f6fabc6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.glue.LogicalPlanAdapter; import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.jobs.batch.NereidsRewriteJobExecutor; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.memo.Memo; @@ -32,10 +33,12 @@ import org.apache.doris.nereids.pattern.PatternDescriptor; import org.apache.doris.nereids.pattern.PatternMatcher; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleFactory; +import org.apache.doris.nereids.rules.RuleSet; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.OriginStatement; @@ -43,7 +46,9 @@ import com.google.common.base.Supplier; import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; +import java.util.List; import java.util.function.Consumer; +import java.util.stream.Collectors; /** * Utility to apply rules to plan and check output plan matches the expected pattern. @@ -54,6 +59,8 @@ public class PlanChecker { private Plan parsedPlan; + private PhysicalPlan physicalPlan; + public PlanChecker(ConnectContext connectContext) { this.connectContext = connectContext; } @@ -88,7 +95,11 @@ public class PlanChecker { return this; } - public PlanChecker applyTopDown(RuleFactory rule) { + public PlanChecker applyTopDown(RuleFactory ruleFactory) { + return applyTopDown(ruleFactory.buildRules()); + } + + public PlanChecker applyTopDown(List rule) { cascadesContext.topDownRewrite(rule); MemoValidator.validate(cascadesContext.getMemo()); return this; @@ -132,6 +143,40 @@ public class PlanChecker { return this; } + public PlanChecker rewrite() { + new NereidsRewriteJobExecutor(cascadesContext).execute(); + return this; + } + + public PlanChecker implement() { + Plan plan = transformToPhysicalPlan(cascadesContext.getMemo().getRoot()); + Assertions.assertTrue(plan instanceof PhysicalPlan); + physicalPlan = ((PhysicalPlan) plan); + return this; + } + + public PhysicalPlan getPhysicalPlan() { + return physicalPlan; + } + + private Plan transformToPhysicalPlan(Group group) { + PhysicalPlan current = null; + loop: + for (Rule rule : RuleSet.IMPLEMENTATION_RULES) { + GroupExpressionMatching matching = new GroupExpressionMatching(rule.getPattern(), group.getLogicalExpression()); + for (Plan plan : matching) { + Plan after = rule.transform(plan, cascadesContext).get(0); + if (after instanceof PhysicalPlan) { + current = (PhysicalPlan) after; + break loop; + } + } + } + Assertions.assertNotNull(current); + return current.withChildren(group.getLogicalExpression().children() + .stream().map(this::transformToPhysicalPlan).collect(Collectors.toList())); + } + public PlanChecker transform(PatternMatcher patternMatcher) { return transform(cascadesContext.getMemo().getRoot(), patternMatcher); }