diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java index b56cfd4ea2..514469ac76 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java @@ -18,10 +18,6 @@ package org.apache.doris.nereids.rules.exploration.mv; import org.apache.doris.common.Pair; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode; import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext; import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; import org.apache.doris.nereids.trees.expressions.Alias; @@ -368,29 +364,11 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate } // Check Aggregate is simple or not and check join is whether valid or not. - // Support join's input can not contain aggregate Only support project, filter, join, logical relation node and - // join condition should be slot reference equals currently + // Support project, filter, join, logical relation node and join condition should only contain + // slot reference equals currently. @Override protected boolean checkPattern(StructInfo structInfo) { - - Plan topPlan = structInfo.getTopPlan(); - Boolean valid = topPlan.accept(StructInfo.AGGREGATE_PATTERN_CHECKER, null); - if (!valid) { - return false; - } - HyperGraph hyperGraph = structInfo.getHyperGraph(); - for (AbstractNode node : hyperGraph.getNodes()) { - StructInfoNode structInfoNode = (StructInfoNode) node; - if (!structInfoNode.getPlan().accept(StructInfo.JOIN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET)) { - return false; - } - } - for (JoinEdge edge : hyperGraph.getJoinEdges()) { - if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET)) { - return false; - } - } - return true; + return structInfo.getTopPlan().accept(StructInfo.PLAN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java index c8a64a740f..03549297bf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java @@ -18,10 +18,6 @@ package org.apache.doris.nereids.rules.exploration.mv; import org.apache.doris.common.Pair; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode; import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; @@ -79,26 +75,11 @@ public abstract class AbstractMaterializedViewJoinRule extends AbstractMateriali } /** - * Check join is whether valid or not. Support join's input can not contain aggregate - * Only support project, filter, join, logical relation node and - * join condition should be slot reference equals currently + * Check join is whether valid or not. Support join's input only support project, filter, join, + * logical relation node and join condition should be slot reference equals currently */ @Override protected boolean checkPattern(StructInfo structInfo) { - HyperGraph hyperGraph = structInfo.getHyperGraph(); - for (AbstractNode node : hyperGraph.getNodes()) { - StructInfoNode structInfoNode = (StructInfoNode) node; - if (!structInfoNode.getPlan().accept(StructInfo.JOIN_PATTERN_CHECKER, - SUPPORTED_JOIN_TYPE_SET)) { - return false; - } - } - for (JoinEdge edge : hyperGraph.getJoinEdges()) { - if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER, - SUPPORTED_JOIN_TYPE_SET)) { - return false; - } - } - return true; + return structInfo.getTopPlan().accept(StructInfo.PLAN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java index 618aaa8684..6a5c82691e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.ObjectId; import org.apache.doris.nereids.trees.plans.Plan; @@ -37,12 +38,11 @@ import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation; import org.apache.doris.nereids.trees.plans.algebra.Filter; import org.apache.doris.nereids.trees.plans.algebra.Join; import org.apache.doris.nereids.trees.plans.algebra.Project; -import org.apache.doris.nereids.trees.plans.algebra.Sort; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor; import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer; import org.apache.doris.nereids.util.ExpressionUtils; @@ -57,7 +57,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -68,8 +67,7 @@ import javax.annotation.Nullable; * modify, if wanting to modify, should copy and then modify */ public class StructInfo { - public static final JoinPatternChecker JOIN_PATTERN_CHECKER = new JoinPatternChecker(); - public static final AggregatePatternChecker AGGREGATE_PATTERN_CHECKER = new AggregatePatternChecker(); + public static final PlanPatternChecker PLAN_PATTERN_CHECKER = new PlanPatternChecker(); // struct info splitter public static final PlanSplitter PLAN_SPLITTER = new PlanSplitter(); private static final RelationCollector RELATION_COLLECTOR = new RelationCollector(); @@ -189,7 +187,7 @@ public class StructInfo { Lists.newArrayList(expression), ImmutableSet.of(), ImmutableSet.of()); - topPlan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext); + structInfoNode.getPlan().accept(ExpressionLineageReplacer.INSTANCE, replaceContext); // Replace expressions by expression map List replacedExpressions = replaceContext.getReplacedExpressions(); shuttledHashConjunctsToConjunctsMap.put(replacedExpressions.get(0), expression); @@ -448,55 +446,58 @@ public class StructInfo { } /** - * JoinPatternChecker + * PlanPatternChecker, this is used to check the plan pattern is valid or not */ - public static class JoinPatternChecker extends DefaultPlanVisitor> { + public static class PlanPatternChecker extends DefaultPlanVisitor> { @Override - public Boolean visit(Plan plan, Set requiredJoinType) { - super.visit(plan, requiredJoinType); - if (!(plan instanceof Filter) - && !(plan instanceof Project) - && !(plan instanceof CatalogRelation) - && !(plan instanceof Join) - && !(plan instanceof Sort) - && !(plan instanceof LogicalAggregate && !((LogicalAggregate) plan).getSourceRepeat() - .isPresent())) { + public Boolean visitLogicalJoin(LogicalJoin join, + Set supportJoinTypes) { + if (!supportJoinTypes.contains(join.getJoinType())) { return false; } - if (plan instanceof Join) { - Join join = (Join) plan; - if (!requiredJoinType.contains(join.getJoinType())) { - return false; - } - if (!join.getOtherJoinConjuncts().isEmpty()) { + if (!join.getOtherJoinConjuncts().isEmpty()) { + return false; + } + return visit(join, supportJoinTypes); + } + + @Override + public Boolean visitLogicalAggregate(LogicalAggregate aggregate, + Set supportJoinTypes) { + if (aggregate.getSourceRepeat().isPresent()) { + return false; + } + return visit(aggregate, supportJoinTypes); + } + + @Override + public Boolean visitGroupPlan(GroupPlan groupPlan, Set supportJoinTypes) { + return groupPlan.getGroup().getLogicalExpressions().stream() + .anyMatch(logicalExpression -> logicalExpression.getPlan().accept(this, supportJoinTypes)); + } + + @Override + public Boolean visit(Plan plan, Set supportJoinTypes) { + if (plan instanceof Filter + || plan instanceof Project + || plan instanceof CatalogRelation + || plan instanceof Join + || plan instanceof LogicalSort + || plan instanceof LogicalAggregate + || plan instanceof GroupPlan) { + return doVisit(plan, supportJoinTypes); + } + return false; + } + + private Boolean doVisit(Plan plan, Set supportJoinTypes) { + for (Plan child : plan.children()) { + boolean valid = child.accept(this, supportJoinTypes); + if (!valid) { return false; } } return true; } } - - /** - * AggregatePatternChecker - */ - public static class AggregatePatternChecker extends DefaultPlanVisitor { - @Override - public Boolean visit(Plan plan, Void context) { - if (plan instanceof LogicalAggregate) { - LogicalAggregate aggregate = (LogicalAggregate) plan; - Optional> sourceRepeat = aggregate.getSourceRepeat(); - if (sourceRepeat.isPresent()) { - return false; - } - super.visit(aggregate, context); - return true; - } - if (plan instanceof Project || plan instanceof Filter || plan instanceof Sort) { - super.visit(plan, context); - return true; - } - super.visit(plan, context); - return false; - } - } } diff --git a/regression-test/data/nereids_rules_p0/mv/join/dphyp_outer/outer_join_dphyp.out b/regression-test/data/nereids_rules_p0/mv/join/dphyp_outer/outer_join_dphyp.out index 845ef3933d..a1020a1e1e 100644 --- a/regression-test/data/nereids_rules_p0/mv/join/dphyp_outer/outer_join_dphyp.out +++ b/regression-test/data/nereids_rules_p0/mv/join/dphyp_outer/outer_join_dphyp.out @@ -295,3 +295,17 @@ 4 0 0 0 43.2000 43.20 43.20 5 0 0 0 28.7000 56.20 1.20 +-- !query9_0_before -- +2023-12-08 1 2 +2023-12-09 2 1 +2023-12-10 3 2 +2023-12-11 4 2 +2023-12-12 5 4 + +-- !query9_0_after -- +2023-12-08 1 2 +2023-12-09 2 1 +2023-12-10 3 2 +2023-12-11 4 2 +2023-12-12 5 4 + diff --git a/regression-test/data/nereids_rules_p0/mv/join/left_outer/outer_join.out b/regression-test/data/nereids_rules_p0/mv/join/left_outer/outer_join.out index 0b171742bc..225e336b91 100644 --- a/regression-test/data/nereids_rules_p0/mv/join/left_outer/outer_join.out +++ b/regression-test/data/nereids_rules_p0/mv/join/left_outer/outer_join.out @@ -311,3 +311,17 @@ 4 0 0 0 43.2000 43.20 43.20 5 0 0 0 28.7000 56.20 1.20 +-- !query9_0_before -- +2023-12-08 1 2 +2023-12-09 2 1 +2023-12-10 3 2 +2023-12-11 4 2 +2023-12-12 5 4 + +-- !query9_0_after -- +2023-12-08 1 2 +2023-12-09 2 1 +2023-12-10 3 2 +2023-12-11 4 2 +2023-12-12 5 4 + diff --git a/regression-test/suites/nereids_rules_p0/mv/join/dphyp_outer/outer_join_dphyp.groovy b/regression-test/suites/nereids_rules_p0/mv/join/dphyp_outer/outer_join_dphyp.groovy index 3d858c57dd..dc9e78e00c 100644 --- a/regression-test/suites/nereids_rules_p0/mv/join/dphyp_outer/outer_join_dphyp.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/join/dphyp_outer/outer_join_dphyp.groovy @@ -466,4 +466,64 @@ suite("outer_join_dphyp") { check_rewrite(mv8_0, query8_0, "mv8_0") order_qt_query8_0_after "${query8_0}" sql """ DROP MATERIALIZED VIEW IF EXISTS mv8_0""" + + // join input with simple agg, use aggregate function as outer group by + def mv9_0 = """ + select + t1.o_orderdate, + t1.o_orderkey, + t1.col1 + from + ( + select + o_orderkey, + o_custkey, + o_orderstatus, + o_orderdate, + sum(o_shippriority) as col1 + from + orders + group by + o_orderkey, + o_custkey, + o_orderstatus, + o_orderdate + ) as t1 + left join lineitem on lineitem.l_orderkey = t1.o_orderkey + group by + t1.o_orderdate, + t1.o_orderkey, + t1.col1 + """ + def query9_0 = """ + select + t1.o_orderdate, + t1.o_orderkey, + t1.col1 + from + ( + select + o_orderkey, + o_custkey, + o_orderstatus, + o_orderdate, + sum(o_shippriority) as col1 + from + orders + group by + o_orderkey, + o_custkey, + o_orderstatus, + o_orderdate + ) as t1 + left join lineitem on lineitem.l_orderkey = t1.o_orderkey + group by + t1.o_orderdate, + t1.o_orderkey, + t1.col1 + """ + order_qt_query9_0_before "${query9_0}" + check_rewrite(mv9_0, query9_0, "mv9_0") + order_qt_query9_0_after "${query9_0}" + sql """ DROP MATERIALIZED VIEW IF EXISTS mv9_0""" } diff --git a/regression-test/suites/nereids_rules_p0/mv/join/left_outer/outer_join.groovy b/regression-test/suites/nereids_rules_p0/mv/join/left_outer/outer_join.groovy index eadb0df155..6a5bda9d22 100644 --- a/regression-test/suites/nereids_rules_p0/mv/join/left_outer/outer_join.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/join/left_outer/outer_join.groovy @@ -476,14 +476,14 @@ suite("outer_join") { def mv6_1 = """ select l_shipdate, t.o_orderdate, l_partkey, l_suppkey, t.o_orderkey from lineitem_null - left join (select o_orderdate,o_orderkey from orders_null where o_orderdate = '2023-12-10' ) t + left join (select o_orderdate,o_orderkey from orders_null where o_orderdate = '2023-12-10' ) t on l_orderkey = t.o_orderkey; """ def query6_1 = """ - select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey - from lineitem_null - left join orders_null - on l_orderkey = o_orderkey + select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey + from lineitem_null + left join orders_null + on l_orderkey = o_orderkey where l_shipdate = '2023-12-10' and o_orderdate = '2023-12-10'; """ order_qt_query6_1_before "${query6_1}" @@ -494,16 +494,16 @@ suite("outer_join") { // should compensate predicate o_orderdate = '2023-12-10' on mv def mv6_2 = """ - select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey - from lineitem - left join (select * from orders where o_orderdate = '2023-12-10' ) t2 + select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey + from lineitem + left join (select * from orders where o_orderdate = '2023-12-10' ) t2 on lineitem.l_orderkey = t2.o_orderkey; """ def query6_2 = """ - select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey - from lineitem - left join orders - on lineitem.l_orderkey = orders.o_orderkey + select l_shipdate, o_orderdate, l_partkey, l_suppkey, o_orderkey + from lineitem + left join orders + on lineitem.l_orderkey = orders.o_orderkey where o_orderdate = '2023-12-10' order by 1, 2, 3, 4, 5; """ order_qt_query6_2_before "${query6_2}" @@ -586,4 +586,65 @@ suite("outer_join") { check_rewrite(mv8_0, query8_0, "mv8_0") order_qt_query8_0_after "${query8_0}" sql """ DROP MATERIALIZED VIEW IF EXISTS mv8_0""" + + + // join input with simple agg, use aggregate function as outer group by + def mv9_0 = """ + select + t1.o_orderdate, + t1.o_orderkey, + t1.col1 + from + ( + select + o_orderkey, + o_custkey, + o_orderstatus, + o_orderdate, + sum(o_shippriority) as col1 + from + orders + group by + o_orderkey, + o_custkey, + o_orderstatus, + o_orderdate + ) as t1 + left join lineitem on lineitem.l_orderkey = t1.o_orderkey + group by + t1.o_orderdate, + t1.o_orderkey, + t1.col1 + """ + def query9_0 = """ + select + t1.o_orderdate, + t1.o_orderkey, + t1.col1 + from + ( + select + o_orderkey, + o_custkey, + o_orderstatus, + o_orderdate, + sum(o_shippriority) as col1 + from + orders + group by + o_orderkey, + o_custkey, + o_orderstatus, + o_orderdate + ) as t1 + left join lineitem on lineitem.l_orderkey = t1.o_orderkey + group by + t1.o_orderdate, + t1.o_orderkey, + t1.col1 + """ + order_qt_query9_0_before "${query9_0}" + check_rewrite(mv9_0, query9_0, "mv9_0") + order_qt_query9_0_after "${query9_0}" + sql """ DROP MATERIALIZED VIEW IF EXISTS mv9_0""" }