diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/CBOUtils.java similarity index 93% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/CBOUtils.java index c49b63cecb..4ab61cfbee 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/CBOUtils.java @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.exploration.join; +package org.apache.doris.nereids.rules.exploration; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -35,8 +34,8 @@ import java.util.stream.Collectors; /** * Common */ -class JoinReorderUtils { - static boolean isAllSlotProject(LogicalProject> project) { +public class CBOUtils { + public static boolean isAllSlotProject(LogicalProject project) { return project.getProjects().stream().allMatch(expr -> expr instanceof Slot); } @@ -44,7 +43,7 @@ class JoinReorderUtils { * Split project according to whether namedExpr contains by splitChildExprIds. * Notice: projects must all be Slot. */ - static Map> splitProject(List projects, + public static Map> splitProject(List projects, Set splitChildExprIds) { return projects.stream() .collect(Collectors.partitioningBy(expr -> { @@ -102,6 +101,9 @@ class JoinReorderUtils { return new LogicalProject<>(projects, plan); } + /** + * Split topJoin Condition to two part according to include bExprIdSet. + */ public static Map> splitConjuncts(List topConjuncts, List bottomConjuncts, Set bExprIdSet) { // top: (A B)(error) (A C) (B C) (A B C) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java index cde94b33eb..09f0f79a9a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java @@ -31,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableList; @@ -57,52 +58,68 @@ import java.util.Set; * (x) * */ -public class EagerCount extends OneExplorationRuleFactory { - public static final EagerCount INSTANCE = new EagerCount(); - +public class EagerCount implements ExplorationRuleFactory { @Override - public Rule build() { - return logicalAggregate(innerLogicalJoin()) - .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) - .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) - .when(agg -> agg.getAggregateFunctions().stream() - .allMatch(f -> f instanceof Sum - && ((Sum) f).child() instanceof SlotReference - && agg.child().left().getOutputSet().contains((SlotReference) ((Sum) f).child()))) - .then(agg -> { - LogicalJoin join = agg.child(); - List rightOutput = join.right().getOutput(); + public List buildRules() { + return ImmutableList.of( + logicalAggregate(innerLogicalJoin()) + .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) + .when(agg -> agg.getAggregateFunctions().stream() + .allMatch(f -> f instanceof Sum + && ((Sum) f).child() instanceof SlotReference + && agg.child().left().getOutputSet() + .contains((SlotReference) ((Sum) f).child()))) + .then(agg -> eagerCount(agg, agg.child(), ImmutableList.of())) + .toRule(RuleType.EAGER_COUNT), + logicalAggregate(logicalProject(innerLogicalJoin())) + .when(agg -> CBOUtils.isAllSlotProject(agg.child())) + .when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) + .when(agg -> agg.getAggregateFunctions().stream() + .allMatch(f -> f instanceof Sum + && ((Sum) f).child() instanceof SlotReference + && agg.child().child().left().getOutputSet() + .contains((SlotReference) ((Sum) f).child()))) + .then(agg -> eagerCount(agg, agg.child().child(), agg.child().getProjects())) + .toRule(RuleType.EAGER_COUNT) + ); + } - Set cntAggGroupBy = new HashSet<>(); - agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(rightOutput::contains) - .forEach(cntAggGroupBy::add); - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (rightOutput.contains(slot)) { - cntAggGroupBy.add(slot); - } - })); - Alias cnt = new Alias(new Count(Literal.of(1)), "cnt"); - List cntAggOutput = ImmutableList.builder() - .addAll(cntAggGroupBy).add(cnt).build(); - LogicalAggregate cntAgg = new LogicalAggregate<>( - ImmutableList.copyOf(cntAggGroupBy), cntAggOutput, join.right()); - Plan newJoin = join.withChildren(join.left(), cntAgg); + private LogicalAggregate eagerCount(LogicalAggregate agg, + LogicalJoin join, List projects) { + List rightOutput = join.right().getOutput(); - List newOutputExprs = new ArrayList<>(); - List sumOutputExprs = new ArrayList<>(); - for (NamedExpression ne : agg.getOutputExpressions()) { - if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) { - sumOutputExprs.add((Alias) ne); - } else { - newOutputExprs.add(ne); - } - } - for (Alias oldSum : sumOutputExprs) { - Sum oldSumFunc = (Sum) oldSum.child(); - newOutputExprs.add(new Alias(oldSum.getExprId(), new Multiply(oldSumFunc, cnt.toSlot()), - oldSum.getName())); - } - return agg.withAggOutputChild(newOutputExprs, newJoin); - }).toRule(RuleType.EAGER_COUNT); + Set cntAggGroupBy = new HashSet<>(); + agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(rightOutput::contains) + .forEach(cntAggGroupBy::add); + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { + if (rightOutput.contains(slot)) { + cntAggGroupBy.add(slot); + } + })); + Alias cnt = new Alias(new Count(Literal.of(1)), "cnt"); + List cntAggOutput = ImmutableList.builder() + .addAll(cntAggGroupBy).add(cnt).build(); + LogicalAggregate cntAgg = new LogicalAggregate<>( + ImmutableList.copyOf(cntAggGroupBy), cntAggOutput, join.right()); + Plan newJoin = join.withChildren(join.left(), cntAgg); + + List newOutputExprs = new ArrayList<>(); + List sumOutputExprs = new ArrayList<>(); + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) { + sumOutputExprs.add((Alias) ne); + } else { + newOutputExprs.add(ne); + } + } + for (Alias oldSum : sumOutputExprs) { + Sum oldSumFunc = (Sum) oldSum.child(); + newOutputExprs.add(new Alias(oldSum.getExprId(), new Multiply(oldSumFunc, cnt.toSlot()), + oldSum.getName())); + } + Plan child = PlanUtils.projectOrSelf(projects, newJoin); + return agg.withAggOutputChild(newOutputExprs, child); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java index 27fcd149b2..22e7d5194e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java @@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableList; @@ -57,59 +58,75 @@ import java.util.stream.Collectors; * After Eager Group By, new plan also can apply `Eager Count`. * It's `Double Eager`. */ -public class EagerGroupBy extends OneExplorationRuleFactory { - public static final EagerGroupBy INSTANCE = new EagerGroupBy(); - +public class EagerGroupBy implements ExplorationRuleFactory { @Override - public Rule build() { - return logicalAggregate(innerLogicalJoin()) - .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) - .when(agg -> agg.getAggregateFunctions().stream() - .allMatch(f -> f instanceof Sum - && ((Sum) f).child() instanceof SlotReference - && agg.child().left().getOutputSet().contains((SlotReference) ((Sum) f).child()))) - .then(agg -> { - LogicalJoin join = agg.child(); - List leftOutput = join.left().getOutput(); - List sums = agg.getAggregateFunctions().stream().map(Sum.class::cast) - .collect(Collectors.toList()); + public List buildRules() { + return ImmutableList.of( + logicalAggregate(innerLogicalJoin()) + .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.getAggregateFunctions().stream() + .allMatch(f -> f instanceof Sum + && ((Sum) f).child() instanceof SlotReference + && agg.child().left().getOutputSet() + .contains((SlotReference) ((Sum) f).child()))) + .then(agg -> eagerGroupBy(agg, agg.child(), ImmutableList.of())) + .toRule(RuleType.EAGER_GROUP_BY), + logicalAggregate(logicalProject(innerLogicalJoin())) + .when(agg -> CBOUtils.isAllSlotProject(agg.child())) + .when(agg -> agg.child().child().getOtherJoinConjuncts().size() == 0) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot)) + .when(agg -> agg.getAggregateFunctions().stream() + .allMatch(f -> f instanceof Sum + && ((Sum) f).child() instanceof SlotReference + && agg.child().child().left().getOutputSet() + .contains((SlotReference) ((Sum) f).child()))) + .then(agg -> eagerGroupBy(agg, agg.child().child(), agg.child().getProjects())) + .toRule(RuleType.EAGER_GROUP_BY) + ); + } - // eager group-by - Set sumAggGroupBy = new HashSet<>(); - agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(leftOutput::contains) - .forEach(sumAggGroupBy::add); - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - sumAggGroupBy.add(slot); - } - })); - List bottomSums = new ArrayList<>(); - for (int i = 0; i < sums.size(); i++) { - bottomSums.add(new Alias(new Sum(sums.get(i).child()), "sum" + i)); - } - List sumAggOutput = ImmutableList.builder() - .addAll(sumAggGroupBy).addAll(bottomSums).build(); - LogicalAggregate sumAgg = new LogicalAggregate<>( - ImmutableList.copyOf(sumAggGroupBy), sumAggOutput, join.left()); - Plan newJoin = join.withChildren(sumAgg, join.right()); + private LogicalAggregate eagerGroupBy(LogicalAggregate agg, + LogicalJoin join, List projects) { + List leftOutput = join.left().getOutput(); + List sums = agg.getAggregateFunctions().stream().map(Sum.class::cast) + .collect(Collectors.toList()); - List newOutputExprs = new ArrayList<>(); - List sumOutputExprs = new ArrayList<>(); - for (NamedExpression ne : agg.getOutputExpressions()) { - if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) { - sumOutputExprs.add((Alias) ne); - } else { - newOutputExprs.add(ne); - } - } - for (int i = 0; i < sumOutputExprs.size(); i++) { - Alias oldSum = sumOutputExprs.get(i); - // sum in bottom Agg - Slot bottomSum = bottomSums.get(i).toSlot(); - Alias newSum = new Alias(oldSum.getExprId(), new Sum(bottomSum), oldSum.getName()); - newOutputExprs.add(newSum); - } - return agg.withAggOutputChild(newOutputExprs, newJoin); - }).toRule(RuleType.EAGER_GROUP_BY); + // eager group-by + Set sumAggGroupBy = new HashSet<>(); + agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(leftOutput::contains) + .forEach(sumAggGroupBy::add); + join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { + if (leftOutput.contains(slot)) { + sumAggGroupBy.add(slot); + } + })); + List bottomSums = new ArrayList<>(); + for (int i = 0; i < sums.size(); i++) { + bottomSums.add(new Alias(new Sum(sums.get(i).child()), "sum" + i)); + } + List sumAggOutput = ImmutableList.builder() + .addAll(sumAggGroupBy).addAll(bottomSums).build(); + LogicalAggregate sumAgg = new LogicalAggregate<>( + ImmutableList.copyOf(sumAggGroupBy), sumAggOutput, join.left()); + Plan newJoin = join.withChildren(sumAgg, join.right()); + + List newOutputExprs = new ArrayList<>(); + List sumOutputExprs = new ArrayList<>(); + for (NamedExpression ne : agg.getOutputExpressions()) { + if (ne instanceof Alias && ((Alias) ne).child() instanceof Sum) { + sumOutputExprs.add((Alias) ne); + } else { + newOutputExprs.add(ne); + } + } + for (int i = 0; i < sumOutputExprs.size(); i++) { + Alias oldSum = sumOutputExprs.get(i); + // sum in bottom Agg + Slot bottomSum = bottomSums.get(i).toSlot(); + Alias newSum = new Alias(oldSum.getExprId(), new Sum(bottomSum), oldSum.getName()); + newOutputExprs.add(newSum); + } + Plan child = PlanUtils.projectOrSelf(projects, newJoin); + return agg.withAggOutputChild(newOutputExprs, child); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java index 9733191605..4c6147ae00 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -55,7 +56,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { .when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) - .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) + .when(join -> CBOUtils.isAllSlotProject(join.left())) .then(topJoin -> { /* ********** init ********** */ LogicalJoin bottomJoin = topJoin.left().child(); @@ -89,15 +90,15 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { Set topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet()); newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); - Plan left = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin); - Plan right = JoinReorderUtils.newProject(topUsedExprIds, b); + Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin); + Plan right = CBOUtils.newProject(topUsedExprIds, b); LogicalJoin newTopJoin = bottomJoin.withConjunctsChildren(newTopHashConjuncts, newTopOtherConjuncts, left, right); newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_INNER_JOIN_LASSCOM_PROJECT); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLeftAssociateProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLeftAssociateProject.java index c70ec914a9..4dd425d0ab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLeftAssociateProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLeftAssociateProject.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -51,7 +52,7 @@ public class InnerJoinLeftAssociateProject extends OneExplorationRuleFactory { .when(InnerJoinLeftAssociate::checkReorder) .whenNot(join -> join.hasJoinHint() || join.right().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.right().child().isMarkJoin()) - .when(join -> JoinReorderUtils.isAllSlotProject(join.right())) + .when(join -> CBOUtils.isAllSlotProject(join.right())) .then(topJoin -> { LogicalJoin bottomJoin = topJoin.right().child(); GroupPlan a = topJoin.left(); @@ -60,11 +61,11 @@ public class InnerJoinLeftAssociateProject extends OneExplorationRuleFactory { Set cExprIdSet = c.getOutputExprIdSet(); // Split condition - Map> splitHashConjuncts = JoinReorderUtils.splitConjuncts( + Map> splitHashConjuncts = CBOUtils.splitConjuncts( topJoin.getHashJoinConjuncts(), bottomJoin.getHashJoinConjuncts(), cExprIdSet); List newTopHashConjuncts = splitHashConjuncts.get(true); List newBottomHashConjuncts = splitHashConjuncts.get(false); - Map> splitOtherConjuncts = JoinReorderUtils.splitConjuncts( + Map> splitOtherConjuncts = CBOUtils.splitConjuncts( topJoin.getOtherJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), cExprIdSet); List newTopOtherConjuncts = splitOtherConjuncts.get(true); List newBottomOtherConjuncts = splitOtherConjuncts.get(false); @@ -81,15 +82,15 @@ public class InnerJoinLeftAssociateProject extends OneExplorationRuleFactory { Set topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet()); newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); - Plan left = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin); - Plan right = JoinReorderUtils.newProject(topUsedExprIds, c); + Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin); + Plan right = CBOUtils.newProject(topUsedExprIds, c); LogicalJoin newTopJoin = bottomJoin.withConjunctsChildren( newTopHashConjuncts, newTopOtherConjuncts, left, right); InnerJoinLeftAssociate.setNewBottomJoinReorder(newBottomJoin, bottomJoin); InnerJoinLeftAssociate.setNewTopJoinReorder(newTopJoin, topJoin); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinRightAssociateProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinRightAssociateProject.java index fcce515d81..cfc4364a38 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinRightAssociateProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinRightAssociateProject.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -49,7 +50,7 @@ public class InnerJoinRightAssociateProject extends OneExplorationRuleFactory { .when(InnerJoinRightAssociate::checkReorder) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) - .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) + .when(join -> CBOUtils.isAllSlotProject(join.left())) .then(topJoin -> { LogicalJoin bottomJoin = topJoin.left().child(); GroupPlan a = bottomJoin.left(); @@ -58,11 +59,11 @@ public class InnerJoinRightAssociateProject extends OneExplorationRuleFactory { Set aExprIdSet = a.getOutputExprIdSet(); // Split condition - Map> splitHashConjuncts = JoinReorderUtils.splitConjuncts( + Map> splitHashConjuncts = CBOUtils.splitConjuncts( topJoin.getHashJoinConjuncts(), bottomJoin.getHashJoinConjuncts(), aExprIdSet); List newTopHashConjuncts = splitHashConjuncts.get(true); List newBottomHashConjuncts = splitHashConjuncts.get(false); - Map> splitOtherConjuncts = JoinReorderUtils.splitConjuncts( + Map> splitOtherConjuncts = CBOUtils.splitConjuncts( topJoin.getOtherJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), aExprIdSet); List newTopOtherConjuncts = splitOtherConjuncts.get(true); List newBottomOtherConjuncts = splitOtherConjuncts.get(false); @@ -78,15 +79,15 @@ public class InnerJoinRightAssociateProject extends OneExplorationRuleFactory { Set topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet()); newTopHashConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); newTopOtherConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); - Plan left = JoinReorderUtils.newProject(topUsedExprIds, a); - Plan right = JoinReorderUtils.newProject(topUsedExprIds, newBottomJoin); + Plan left = CBOUtils.newProject(topUsedExprIds, a); + Plan right = CBOUtils.newProject(topUsedExprIds, newBottomJoin); LogicalJoin newTopJoin = bottomJoin.withConjunctsChildren( newTopHashConjuncts, newTopOtherConjuncts, left, right); setNewBottomJoinReorder(newBottomJoin, bottomJoin); setNewTopJoinReorder(newTopJoin, topJoin); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeBothProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeBothProject.java index a921ef6fcb..0543fcefb7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeBothProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeBothProject.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -53,8 +54,8 @@ public class JoinExchangeBothProject extends OneExplorationRuleFactory { public Rule build() { return innerLogicalJoin(logicalProject(innerLogicalJoin()), logicalProject(innerLogicalJoin())) .when(JoinExchange::checkReorder) - .when(join -> JoinReorderUtils.isAllSlotProject(join.left()) - && JoinReorderUtils.isAllSlotProject(join.right())) + .when(join -> CBOUtils.isAllSlotProject(join.left()) + && CBOUtils.isAllSlotProject(join.right())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint() || join.right().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin() || join.right().child().isMarkJoin()) @@ -95,8 +96,8 @@ public class JoinExchangeBothProject extends OneExplorationRuleFactory { Set topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet()); newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); - Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin); - Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin); + Plan left = CBOUtils.newProject(topUsedExprIds, newLeftJoin); + Plan right = CBOUtils.newProject(topUsedExprIds, newRightJoin); LogicalJoin newTopJoin = new LogicalJoin<>(JoinType.INNER_JOIN, newTopJoinHashJoinConjuncts, newTopJoinOtherJoinConjuncts, JoinHint.NONE, left, right); @@ -104,7 +105,7 @@ public class JoinExchangeBothProject extends OneExplorationRuleFactory { JoinExchange.setNewRightJoinReorder(newRightJoin, leftJoin); JoinExchange.setNewTopJoinReorder(newTopJoin, topJoin); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_JOIN_EXCHANGE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeLeftProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeLeftProject.java index 73f2bfcc2b..9f8013f1d8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeLeftProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeLeftProject.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -53,7 +54,7 @@ public class JoinExchangeLeftProject extends OneExplorationRuleFactory { public Rule build() { return innerLogicalJoin(logicalProject(innerLogicalJoin()), innerLogicalJoin()) .when(JoinExchange::checkReorder) - .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) + .when(join -> CBOUtils.isAllSlotProject(join.left())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint() || join.right().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin() || join.right().isMarkJoin()) @@ -94,8 +95,8 @@ public class JoinExchangeLeftProject extends OneExplorationRuleFactory { Set topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet()); newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); - Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin); - Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin); + Plan left = CBOUtils.newProject(topUsedExprIds, newLeftJoin); + Plan right = CBOUtils.newProject(topUsedExprIds, newRightJoin); LogicalJoin newTopJoin = new LogicalJoin<>(JoinType.INNER_JOIN, newTopJoinHashJoinConjuncts, newTopJoinOtherJoinConjuncts, JoinHint.NONE, left, right); @@ -103,7 +104,7 @@ public class JoinExchangeLeftProject extends OneExplorationRuleFactory { JoinExchange.setNewRightJoinReorder(newRightJoin, leftJoin); JoinExchange.setNewTopJoinReorder(newTopJoin, topJoin); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_JOIN_EXCHANGE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeRightProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeRightProject.java index d94b0fab69..f5df8917c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeRightProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchangeRightProject.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; @@ -53,7 +54,7 @@ public class JoinExchangeRightProject extends OneExplorationRuleFactory { public Rule build() { return innerLogicalJoin(innerLogicalJoin(), logicalProject(innerLogicalJoin())) .when(JoinExchange::checkReorder) - .when(join -> JoinReorderUtils.isAllSlotProject(join.right())) + .when(join -> CBOUtils.isAllSlotProject(join.right())) .whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint() || join.right().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin() || join.right().child().isMarkJoin()) @@ -94,8 +95,8 @@ public class JoinExchangeRightProject extends OneExplorationRuleFactory { Set topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet()); newTopJoinHashJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); newTopJoinOtherJoinConjuncts.forEach(expr -> topUsedExprIds.addAll(expr.getInputSlotExprIds())); - Plan left = JoinReorderUtils.newProject(topUsedExprIds, newLeftJoin); - Plan right = JoinReorderUtils.newProject(topUsedExprIds, newRightJoin); + Plan left = CBOUtils.newProject(topUsedExprIds, newLeftJoin); + Plan right = CBOUtils.newProject(topUsedExprIds, newRightJoin); LogicalJoin newTopJoin = new LogicalJoin<>(JoinType.INNER_JOIN, newTopJoinHashJoinConjuncts, newTopJoinOtherJoinConjuncts, JoinHint.NONE, left, right); @@ -103,7 +104,7 @@ public class JoinExchangeRightProject extends OneExplorationRuleFactory { JoinExchange.setNewRightJoinReorder(newRightJoin, leftJoin); JoinExchange.setNewTopJoinReorder(newTopJoin, topJoin); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_JOIN_EXCHANGE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java index 064b79dbb6..bcb66436b5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/LogicalJoinSemiJoinTransposeProject.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.ExplorationRuleFactory; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; @@ -45,7 +46,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto || topJoin.getJoinType().isLeftOuterJoin()))) .whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().child().hasJoinHint()) .whenNot(LogicalJoin::isMarkJoin) - .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) + .when(join -> CBOUtils.isAllSlotProject(join.left())) .then(topJoin -> { LogicalJoin bottomJoin = topJoin.left().child(); GroupPlan a = bottomJoin.left(); @@ -55,7 +56,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto // Discard this project, because it is useless. Plan newBottomJoin = topJoin.withChildrenNoContext(a, c); Plan newTopJoin = bottomJoin.withChildrenNoContext(newBottomJoin, b); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE_PROJECT), @@ -63,7 +64,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto .when(topJoin -> (topJoin.right().child().getJoinType().isLeftSemiOrAntiJoin() && (topJoin.getJoinType().isInnerJoin() || topJoin.getJoinType().isRightOuterJoin()))) - .when(join -> JoinReorderUtils.isAllSlotProject(join.right())) + .when(join -> CBOUtils.isAllSlotProject(join.right())) .then(topJoin -> { LogicalJoin bottomJoin = topJoin.right().child(); GroupPlan a = topJoin.left(); @@ -73,7 +74,7 @@ public class LogicalJoinSemiJoinTransposeProject implements ExplorationRuleFacto // Discard this project, because it is useless. Plan newBottomJoin = topJoin.withChildrenNoContext(a, b); Plan newTopJoin = bottomJoin.withChildrenNoContext(newBottomJoin, c); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE_PROJECT) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java index 561f03998e..286d03c505 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -58,7 +59,7 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory { .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) .when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet())) - .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) + .when(join -> CBOUtils.isAllSlotProject(join.left())) .then(topJoin -> { /* ********** init ********** */ List projects = topJoin.left().getProjects(); @@ -69,7 +70,7 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory { Set aOutputExprIds = a.getOutputExprIdSet(); /* ********** Split projects ********** */ - Map> map = JoinReorderUtils.splitProject(projects, aOutputExprIds); + Map> map = CBOUtils.splitProject(projects, aOutputExprIds); List aProjects = map.get(true); List bProjects = map.get(false); if (bProjects.isEmpty()) { @@ -91,22 +92,22 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory { .flatMap(onExpr -> onExpr.getInputSlots().stream()) .collect(Collectors.partitioningBy( slot -> aOutputExprIds.contains(slot.getExprId()), Collectors.toSet())); - JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects); - JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects); + CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects); + CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects); bProjects.addAll(OuterJoinLAsscomProject.forceToNullable(c.getOutputSet())); /* ********** new Plan ********** */ LogicalJoin newBottomJoin = topJoin.withChildrenNoContext(b, c); newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext()); - Plan left = JoinReorderUtils.projectOrSelf(aProjects, a); - Plan right = JoinReorderUtils.projectOrSelf(bProjects, newBottomJoin); + Plan left = CBOUtils.projectOrSelf(aProjects, a); + Plan right = CBOUtils.projectOrSelf(bProjects, newBottomJoin); LogicalJoin newTopJoin = bottomJoin.withChildrenNoContext(left, right); newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext()); OuterJoinAssoc.setReorderContext(newTopJoin, newBottomJoin); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_OUTER_JOIN_ASSOC_PROJECT); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java index 2b52cd9157..72d0982edd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -60,7 +61,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) - .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) + .when(join -> CBOUtils.isAllSlotProject(join.left())) .then(topJoin -> { /* ********** init ********** */ List projects = topJoin.left().getProjects(); @@ -71,7 +72,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { Set aOutputExprIds = a.getOutputExprIdSet(); /* ********** Split projects ********** */ - Map> map = JoinReorderUtils.splitProject(projects, aOutputExprIds); + Map> map = CBOUtils.splitProject(projects, aOutputExprIds); List aProjects = map.get(true); if (aProjects.isEmpty()) { return null; @@ -93,8 +94,8 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { .flatMap(onExpr -> onExpr.getInputSlots().stream()) .collect(Collectors.partitioningBy( slot -> aOutputExprIds.contains(slot.getExprId()), Collectors.toSet())); - JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects); - JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects); + CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects); + CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects); aProjects.addAll(forceToNullable(c.getOutputSet())); @@ -104,14 +105,14 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); - Plan left = JoinReorderUtils.projectOrSelf(aProjects, newBottomJoin); - Plan right = JoinReorderUtils.projectOrSelf(bProjects, b); + Plan left = CBOUtils.projectOrSelf(aProjects, newBottomJoin); + Plan right = CBOUtils.projectOrSelf(bProjects, b); LogicalJoin newTopJoin = bottomJoin.withChildrenNoContext(left, right); newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_OUTER_JOIN_LASSCOM_PROJECT); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java index fb98df8408..bc1c45f243 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -51,7 +52,7 @@ public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory { @Override public Rule build() { return logicalProject(logicalJoin()) - .whenNot(JoinReorderUtils::isAllSlotProject) + .whenNot(CBOUtils::isAllSlotProject) .when(project -> project.child().getJoinType().isInnerJoin()) .whenNot(project -> project.child().hasJoinHint()) .then(project -> { @@ -87,24 +88,24 @@ public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory { } Builder newAProject = ImmutableList.builder().addAll(aProjects); - Set aConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, true); + Set aConditionSlots = CBOUtils.joinChildConditionSlots(join, true); Set aProjectSlots = aProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet()); aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add); - Plan newLeft = JoinReorderUtils.projectOrSelf(newAProject.build(), join.left()); + Plan newLeft = CBOUtils.projectOrSelf(newAProject.build(), join.left()); if (!rightContains) { Plan newJoin = join.withChildrenNoContext(newLeft, join.right()); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); + return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); } Builder newBProject = ImmutableList.builder().addAll(bProjects); - Set bConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, false); + Set bConditionSlots = CBOUtils.joinChildConditionSlots(join, false); Set bProjectSlots = bProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet()); bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add); - Plan newRight = JoinReorderUtils.projectOrSelf(newBProject.build(), join.right()); + Plan newRight = CBOUtils.projectOrSelf(newBProject.build(), join.right()); Plan newJoin = join.withChildrenNoContext(newLeft, newRight); - return JoinReorderUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin); + return CBOUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin); }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java index b6316511da..c248874589 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -51,20 +52,20 @@ public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory { return logicalProject(logicalJoin()) .when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin()) // Just pushdown project with non-column expr like (t.id + 1) - .whenNot(JoinReorderUtils::isAllSlotProject) + .whenNot(CBOUtils::isAllSlotProject) .whenNot(project -> project.child().hasJoinHint()) .then(project -> { LogicalJoin join = project.child(); - Set conditionLeftSlots = JoinReorderUtils.joinChildConditionSlots(join, true); + Set conditionLeftSlots = CBOUtils.joinChildConditionSlots(join, true); List newProject = new ArrayList<>(project.getProjects()); Set projectUsedSlots = project.getProjects().stream().map(NamedExpression::toSlot) .collect(Collectors.toSet()); conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)).forEach(newProject::add); - Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left()); + Plan newLeft = CBOUtils.projectOrSelf(newProject, join.left()); Plan newJoin = join.withChildrenNoContext(newLeft, join.right()); - return JoinReorderUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin); + return CBOUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin); }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java index bc24162dec..85be57370d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -55,7 +56,7 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory .when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) - .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) + .when(join -> CBOUtils.isAllSlotProject(join.left())) .then(topSemi -> { LogicalJoin bottomSemi = topSemi.left().child(); LogicalProject abProject = topSemi.left(); @@ -80,7 +81,7 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory LogicalJoin newTopSemi = bottomSemi.withChildrenNoContext(acProject, b); newTopSemi.getJoinReorderContext().copyFrom(topSemi.getJoinReorderContext()); newTopSemi.getJoinReorderContext().setHasLAsscom(true); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemi.getOutput()), newTopSemi); + return CBOUtils.projectOrSelf(new ArrayList<>(topSemi.getOutput()), newTopSemi); }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index f68664bc32..500abc433a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -19,12 +19,15 @@ package org.apache.doris.nereids.util; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.collect.Sets; +import java.util.List; import java.util.Optional; import java.util.Set; @@ -54,4 +57,14 @@ public class PlanUtils { return buffer.isEmpty() ? expression : expression.commute(); } + public static Optional> project(List projects, Plan plan) { + if (projects.isEmpty()) { + return Optional.empty(); + } + return Optional.of(new LogicalProject<>(projects, plan)); + } + + public static Plan projectOrSelf(List projects, Plan plan) { + return project(projects, plan).map(Plan.class::cast).orElse(plan); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerCountTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerCountTest.java index dea0f77814..ca65e4be78 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerCountTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerCountTest.java @@ -48,7 +48,7 @@ class EagerCountTest implements MemoPatternMatchSupported { ImmutableList.of(new Alias(new Sum(scan1.getOutput().get(1)), "sum"))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), agg) - .applyExploration(EagerCount.INSTANCE.build()) + .applyExploration(new EagerCount().buildRules()) .matchesExploration( logicalAggregate( logicalJoin( @@ -72,7 +72,7 @@ class EagerCountTest implements MemoPatternMatchSupported { )) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), agg) - .applyExploration(EagerCount.INSTANCE.build()) + .applyExploration(new EagerCount().buildRules()) .printlnOrigin() .matchesExploration( logicalAggregate( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByTest.java index 470b0a752a..afa73ae551 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByTest.java @@ -48,7 +48,7 @@ class EagerGroupByTest implements MemoPatternMatchSupported { ImmutableList.of(new Alias(new Sum(scan1.getOutput().get(3)), "sum"))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), agg) - .applyExploration(EagerGroupBy.INSTANCE.build()) + .applyExploration(new EagerGroupBy().buildRules()) .matchesExploration( logicalAggregate( logicalJoin( @@ -73,7 +73,7 @@ class EagerGroupByTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), agg) - .applyExploration(EagerGroupBy.INSTANCE.build()) + .applyExploration(new EagerGroupBy().buildRules()) .matchesExploration( logicalAggregate( logicalJoin(