diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPushDownVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPushDownVisitor.java index 30115a0f57..2ec0cb3a6f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPushDownVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPushDownVisitor.java @@ -182,7 +182,9 @@ public class RuntimeFilterPushDownVisitor extends PlanVisitor join, PushDownContext ctx) { + if (!ctx.builderNode.equals(join) + && !join.getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + return false; + } boolean pushed = false; if (ctx.builderNode instanceof PhysicalHashJoin) { @@ -314,6 +320,9 @@ public class RuntimeFilterPushDownVisitor extends PlanVisitor join, PushDownContext ctx) { + if (!join.getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + return false; + } if (ctx.builderNode instanceof PhysicalHashJoin) { /* hashJoin( t1.A <=> t2.A ) @@ -342,6 +351,9 @@ public class RuntimeFilterPushDownVisitor extends PlanVisitor project, PushDownContext ctx) { + if (!project.getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + return false; + } // project ( A+1 as x) // probeExpr: abs(x) => abs(A+1) PushDownContext ctxProjectProbeExpr = ctx; @@ -384,6 +396,9 @@ public class RuntimeFilterPushDownVisitor extends PlanVisitor window, PushDownContext ctx) { + if (!window.getOutputSet().containsAll(ctx.probeExpr.getInputSlots())) { + return false; + } + Set commonPartitionKeys = window.getCommonPartitionKeyFromWindowExpressions(); if (commonPartitionKeys.containsAll(ctx.probeExpr.getInputSlots())) { return window.child().accept(this, ctx); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java index f354ed5b02..8f9e1a5b6e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/RuntimeFilterTest.java @@ -22,9 +22,23 @@ import org.apache.doris.nereids.datasets.ssb.SSBTestBase; import org.apache.doris.nereids.datasets.ssb.SSBUtils; import org.apache.doris.nereids.glue.translator.PhysicalPlanTranslator; import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; +import org.apache.doris.nereids.hint.DistributeHint; import org.apache.doris.nereids.processor.post.PlanPostProcessors; import org.apache.doris.nereids.processor.post.RuntimeFilterContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.plans.DistributeType; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalPlan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import org.apache.doris.nereids.trees.plans.physical.RuntimeFilter; import org.apache.doris.nereids.util.PlanChecker; @@ -33,6 +47,7 @@ import com.google.common.collect.Sets; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.Set; @@ -344,4 +359,65 @@ public class RuntimeFilterTest extends SSBTestBase { Assertions.assertEquals(0, filters.size()); } + @Test + public void testNotGenerateRfOnDanglingSlot() { + String sql = "select lo_custkey from lineorder union all select c_custkey from customer union all select p_partkey from part;"; + PlanChecker checker = PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .implement(); + PhysicalPlan plan = checker.getPhysicalPlan(); + + /* construct plan for + join (#18=p_partkey) + -->join() + -->project(null as #18, ...) + -->lineorder + -->project(c_custkey#17) + -->customer(output: c_custkey#17, c_name#18, ...) + -->project(p_partkey#25) + -->part + + test purpose: + do not generate RF by "#18=p_partkey" and apply this rf on customer + */ + PhysicalProject projectCustomer = (PhysicalProject) plan.child(0).child(1); + SlotReference cCustkey = (SlotReference) projectCustomer.getProjects().get(0); + PhysicalProject projectPart = (PhysicalProject) plan.child(0).child(2); + SlotReference pPartkey = (SlotReference) projectPart.getProjects().get(0); + + PhysicalOlapScan lo = (PhysicalOlapScan) plan.child(0).child(0).child(0); + SlotReference loCustkey = (SlotReference) lo.getBaseOutputs().get(2); + SlotReference loPartkey = (SlotReference) lo.getBaseOutputs().get(3); + Alias nullAlias = new Alias(new ExprId(18), new NullLiteral(), ""); // expr#18 is used by c_name + List projList = new ArrayList<>(); + projList.add(loCustkey); + projList.add(loPartkey); + projList.add(nullAlias); + PhysicalProject projLo = new PhysicalProject(projList, null, lo); + + PhysicalHashJoin joinLoC = new PhysicalHashJoin(JoinType.INNER_JOIN, + ImmutableList.of(new EqualTo(loCustkey, cCustkey)), + ImmutableList.of(), + new DistributeHint(DistributeType.NONE), + Optional.empty(), + null, + projLo, + projectCustomer + ); + PhysicalHashJoin joinLoCP = new PhysicalHashJoin(JoinType.INNER_JOIN, + ImmutableList.of(new EqualTo(nullAlias.toSlot(), pPartkey)), + ImmutableList.of(), + new DistributeHint(DistributeType.NONE), + Optional.empty(), + null, + joinLoC, + projectPart + ); + checker.getCascadesContext().getConnectContext().getSessionVariable().enableRuntimeFilterPrune = false; + plan = new PlanPostProcessors(checker.getCascadesContext()).process(joinLoCP); + System.out.println(plan.treeString()); + Assertions.assertEquals(0, ((AbstractPhysicalPlan) plan.child(0).child(1).child(0)) + .getAppliedRuntimeFilters().size()); + } }