diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 9f04ae38cb..7b721a4099 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -580,6 +580,10 @@ public class HyperGraph { LongBitmap.getIterator(addedNodes).forEach(index -> nodes.get(index).attachEdge(edge)); } + public int edgeSize() { + return joinEdges.size() + filterEdges.size(); + } + /** * compare hypergraph * @@ -588,20 +592,37 @@ public class HyperGraph { * be pull up from this hyper graph */ public @Nullable List isLogicCompatible(HyperGraph viewHG, LogicalCompatibilityContext ctx) { - if (viewHG.filterEdges.size() != filterEdges.size() && viewHG.joinEdges.size() != joinEdges.size()) { - return null; - } Map queryToView = constructEdgeMap(viewHG, ctx.getQueryToViewEdgeExpressionMapping()); - if (queryToView.size() != filterEdges.size() + joinEdges.size()) { + + // All edge in view must have a mapped edge in query + if (queryToView.size() != viewHG.edgeSize()) { return null; } + boolean allMatch = queryToView.entrySet().stream() .allMatch(entry -> compareEdgeWithNode(entry.getKey(), entry.getValue(), ctx.getQueryToViewNodeIDMapping())); if (!allMatch) { return null; } - return ImmutableList.of(); + + // join edges must be identical + boolean isJoinIdentical = joinEdges.stream() + .allMatch(queryToView::containsKey); + if (!isJoinIdentical) { + return null; + } + + // extract all top filters + List residualFilterEdges = filterEdges.stream() + .filter(e -> !queryToView.containsKey(e)) + .collect(ImmutableList.toImmutableList()); + if (residualFilterEdges.stream().anyMatch(e -> !e.isTopFilter())) { + return null; + } + return residualFilterEdges.stream() + .flatMap(e -> e.getExpressions().stream()) + .collect(ImmutableList.toImmutableList()); } private Map constructEdgeMap(HyperGraph viewHG, Map exprMap) { @@ -622,13 +643,19 @@ public class HyperGraph { private boolean compareEdgeWithNode(Edge t, Edge o, Map nodeMap) { if (t instanceof FilterEdge && o instanceof FilterEdge) { - return false; + return compareEdgeWithFilter((FilterEdge) t, (FilterEdge) o, nodeMap); } else if (t instanceof JoinEdge && o instanceof JoinEdge) { return compareJoinEdge((JoinEdge) t, (JoinEdge) o, nodeMap); } return false; } + private boolean compareEdgeWithFilter(FilterEdge t, FilterEdge o, Map nodeMap) { + long tChild = t.getReferenceNodes(); + long oChild = o.getReferenceNodes(); + return compareNodeMap(tChild, oChild, nodeMap); + } + private boolean compareJoinEdge(JoinEdge t, JoinEdge o, Map nodeMap) { long tLeft = t.getLeftExtendedNodes(); long tRight = t.getRightExtendedNodes(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java index d04031067d..ec03787102 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java @@ -49,6 +49,10 @@ public class FilterEdge extends Edge { return rejectEdges; } + public boolean isTopFilter() { + return rejectEdges.isEmpty(); + } + @Override public Set getInputSlots() { return filter.getInputSlots(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/CompareOuterJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/CompareOuterJoinTest.java index f8ba3ead2e..d0e5084a1c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/CompareOuterJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/CompareOuterJoinTest.java @@ -25,13 +25,17 @@ import org.apache.doris.nereids.rules.exploration.mv.StructInfo; import org.apache.doris.nereids.rules.exploration.mv.mapping.RelationMapping; import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; import org.apache.doris.nereids.sqltest.SqlTestBase; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.util.HyperGraphBuilder; import org.apache.doris.nereids.util.PlanChecker; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import java.util.List; + class CompareOuterJoinTest extends SqlTestBase { @Test void testStarGraphWithInnerJoin() { @@ -50,12 +54,12 @@ class CompareOuterJoinTest extends SqlTestBase { Plan p1 = PlanChecker.from(c1) .analyze() .rewrite() - .getPlan(); + .getPlan().child(0); Plan p2 = PlanChecker.from(c1) .analyze() .rewrite() .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) - .getAllPlan().get(0); + .getAllPlan().get(0).child(0); HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); Assertions.assertTrue(h1.isLogicCompatible(h2, constructContext(p1, p2)) != null); @@ -78,6 +82,115 @@ class CompareOuterJoinTest extends SqlTestBase { Assertions.assertTrue(h1.isLogicCompatible(h2, constructContext(p1, p2)) != null); } + @Test + void testInnerJoinWithFilter() { + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + CascadesContext c1 = createCascadesContext( + "select * from T1 inner join T2 on T1.id = T2.id where T1.id = 0", + connectContext + ); + Plan p1 = PlanChecker.from(c1) + .analyze() + .rewrite() + .getPlan().child(0); + CascadesContext c2 = createCascadesContext( + "select * from T1 inner join T2 on T1.id = T2.id", + connectContext + ); + Plan p2 = PlanChecker.from(c2) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); + HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); + HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + List exprList = h1.isLogicCompatible(h2, constructContext(p1, p2)); + Assertions.assertEquals(1, exprList.size()); + Assertions.assertEquals("(id = 0)", exprList.get(0).toSql()); + } + + @Disabled + @Test + void testInnerJoinWithFilter2() { + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + CascadesContext c1 = createCascadesContext( + "select * from T1 inner join T2 on T1.id = T2.id where T1.id = 0", + connectContext + ); + Plan p1 = PlanChecker.from(c1) + .analyze() + .rewrite() + .getPlan().child(0); + CascadesContext c2 = createCascadesContext( + "select * from T1 inner join T2 on T1.id = T2.id where T1.id = 0", + connectContext + ); + Plan p2 = PlanChecker.from(c2) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); + HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); + HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + List exprList = h1.isLogicCompatible(h2, constructContext(p1, p2)); + Assertions.assertEquals(0, exprList.size()); + } + + @Test + void testLeftOuterJoinWithLeftFilter() { + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + CascadesContext c1 = createCascadesContext( + "select * from ( select * from T1 where T1.id = 0) T1 left outer join T2 on T1.id = T2.id", + connectContext + ); + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + Plan p1 = PlanChecker.from(c1) + .analyze() + .rewrite() + .getPlan().child(0); + CascadesContext c2 = createCascadesContext( + "select * from T1 left outer join T2 on T1.id = T2.id", + connectContext + ); + Plan p2 = PlanChecker.from(c2) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); + HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); + HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + List exprList = h1.isLogicCompatible(h2, constructContext(p1, p2)); + Assertions.assertEquals(1, exprList.size()); + Assertions.assertEquals("(id = 0)", exprList.get(0).toSql()); + } + + @Test + void testLeftOuterJoinWithRightFilter() { + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + CascadesContext c1 = createCascadesContext( + "select * from T1 left outer join ( select * from T2 where T2.id = 0) T2 on T1.id = T2.id", + connectContext + ); + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES"); + Plan p1 = PlanChecker.from(c1) + .analyze() + .rewrite() + .getPlan().child(0); + CascadesContext c2 = createCascadesContext( + "select * from T1 left outer join T2 on T1.id = T2.id", + connectContext + ); + Plan p2 = PlanChecker.from(c2) + .analyze() + .rewrite() + .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER) + .getAllPlan().get(0).child(0); + HyperGraph h1 = HyperGraph.toStructInfo(p1).get(0); + HyperGraph h2 = HyperGraph.toStructInfo(p2).get(0); + List exprList = h1.isLogicCompatible(h2, constructContext(p1, p2)); + Assertions.assertEquals(null, exprList); + } + LogicalCompatibilityContext constructContext(Plan p1, Plan p2) { StructInfo st1 = AbstractMaterializedViewRule.extractStructInfo(p1, null).get(0);