diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index d636f83144..e87cd11b94 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -229,6 +229,8 @@ public enum RuleType { LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION), + LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), + PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION), // implementation rules LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION), @@ -267,7 +269,6 @@ public enum RuleType { LOGICAL_WINDOW_TO_PHYSICAL_WINDOW_RULE(RuleTypeClass.IMPLEMENTATION), IMPLEMENTATION_SENTINEL(RuleTypeClass.IMPLEMENTATION), - LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), // sentinel, use to count rules SENTINEL(RuleTypeClass.SENTINEL), ; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java index bc8656520f..3581bb81c0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java @@ -36,6 +36,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Rule for change inner join LAsscom (associative and commutive). @@ -57,7 +58,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { return innerLogicalJoin(logicalProject(innerLogicalJoin()), group()) .when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) - .when(join -> JoinReorderUtils.checkProject(join.left())) + .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) .then(topJoin -> { /* ********** init ********** */ List projects = topJoin.left().getProjects(); @@ -93,18 +94,22 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { List newTopOtherConjuncts = splitOtherConjuncts.get(true); List newBottomOtherConjuncts = splitOtherConjuncts.get(false); - JoinReorderHelper helper = new JoinReorderHelper(newTopHashConjuncts, newTopOtherConjuncts, - newBottomHashConjuncts, newBottomOtherConjuncts, projects, aProjects, bProjects); - // Add all slots used by OnCondition when projects not empty. - helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects), - c.getOutputExprIdSet()); + Set aExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(a, aProjects); + Map> abOnUsedSlots = Stream.concat( + bottomJoin.getHashJoinConjuncts().stream(), + bottomJoin.getHashJoinConjuncts().stream()) + .flatMap(onExpr -> onExpr.getInputSlots().stream()) + .collect(Collectors.partitioningBy( + slot -> aExprIdSet.contains(slot.getExprId()), Collectors.toSet())); + JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects); + JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects); aProjects.addAll(cOutputSet); /* ********** new Plan ********** */ - LogicalJoin newBottomJoin = topJoin.withConjunctsChildren(helper.newBottomHashConjuncts, - helper.newBottomOtherConjuncts, a, c); + LogicalJoin newBottomJoin = topJoin.withConjunctsChildren(newBottomHashConjuncts, + newBottomOtherConjuncts, a, c); newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); @@ -112,8 +117,8 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { Plan left = JoinReorderUtils.projectOrSelf(aProjects, newBottomJoin); Plan right = JoinReorderUtils.projectOrSelf(bProjects, b); - LogicalJoin newTopJoin = bottomJoin.withConjunctsChildren(helper.newTopHashConjuncts, - helper.newTopOtherConjuncts, left, right); + LogicalJoin newTopJoin = bottomJoin.withConjunctsChildren(newTopHashConjuncts, + newTopOtherConjuncts, left, right); newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderHelper.java deleted file mode 100644 index de52786886..0000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderHelper.java +++ /dev/null @@ -1,99 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.exploration.join; - -import org.apache.doris.nereids.trees.expressions.ExprId; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; - -import com.google.common.base.Preconditions; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -/** - * Helper class for three left deep tree ( original plan tree is (a b) c )create new join. - */ -public class JoinReorderHelper { - public List newTopHashConjuncts; - public List newTopOtherConjuncts; - public List newBottomHashConjuncts; - public List newBottomOtherConjuncts; - - public List oldProjects; - public List newLeftProjects; - public List newRightProjects; - - /** - * Constructor. - */ - public JoinReorderHelper(List newTopHashConjuncts, List newTopOtherConjuncts, - List newBottomHashConjuncts, List newBottomOtherConjuncts, - List oldProjects, List newLeftProjects, - List newRightProjects) { - this.newTopHashConjuncts = newTopHashConjuncts; - this.newTopOtherConjuncts = newTopOtherConjuncts; - this.newBottomHashConjuncts = newBottomHashConjuncts; - this.newBottomOtherConjuncts = newBottomOtherConjuncts; - this.oldProjects = oldProjects; - this.newLeftProjects = newLeftProjects; - this.newRightProjects = newRightProjects; - replaceConjuncts(oldProjects); - } - - private void replaceConjuncts(List projects) { - Map inputToOutput = new HashMap<>(); - Map outputToInput = new HashMap<>(); - for (NamedExpression expr : projects) { - Slot outputSlot = expr.toSlot(); - Set usedSlots = expr.getInputSlots(); - Preconditions.checkState(usedSlots.size() == 1); - Slot inputSlot = usedSlots.iterator().next(); - inputToOutput.put(inputSlot.getExprId(), outputSlot); - outputToInput.put(outputSlot.getExprId(), inputSlot); - } - - newBottomHashConjuncts = JoinReorderUtils.replaceJoinConjuncts(newBottomHashConjuncts, outputToInput); - newTopHashConjuncts = JoinReorderUtils.replaceJoinConjuncts(newTopHashConjuncts, inputToOutput); - newBottomOtherConjuncts = JoinReorderUtils.replaceJoinConjuncts(newBottomOtherConjuncts, outputToInput); - newTopOtherConjuncts = JoinReorderUtils.replaceJoinConjuncts(newTopOtherConjuncts, inputToOutput); - } - - /** - * Add all slots used by OnCondition when projects not empty. - * @param cOutputExprIdSet we want to get abOnUsedSlots, we need filter cOutputExprIdSet. - */ - public void addSlotsUsedByOn(Set splitIds, Set cOutputExprIdSet) { - Map> abOnUsedSlots = Stream.concat( - newTopHashConjuncts.stream(), - newTopOtherConjuncts.stream()) - .flatMap(onExpr -> onExpr.getInputSlots().stream()) - .filter(slot -> !cOutputExprIdSet.contains(slot.getExprId())) - .collect(Collectors.partitioningBy(slot -> splitIds.contains(slot.getExprId()), Collectors.toSet())); - Set aUsedSlots = abOnUsedSlots.get(true); - Set bUsedSlots = abOnUsedSlots.get(false); - - JoinReorderUtils.addSlotsUsedByOn(aUsedSlots, newLeftProjects); - JoinReorderUtils.addSlotsUsedByOn(bUsedSlots, newRightProjects); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java index 645e74c7f3..b723e2e4a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java @@ -17,7 +17,6 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -40,23 +39,16 @@ import java.util.stream.Stream; */ class JoinReorderUtils { /** - * check project inside Join to prevent matching some pattern. - * just allow projection is slot or Alias(slot) to prevent reorder when: - * - output of project function is in condition, A join (project [abs(B.id), ..] B join C on ..) on abs(B.id)=A.id. - * - hyper edge in projection. project A.id + B.id A join B on .. (this project will prevent join reorder). + * check project Expression Input Slot just contains one slot, like: + * - one SlotReference like a.id + * - Input Slot size == 1, like abs(a.id) + 1 */ - static boolean checkProject(LogicalProject> project) { - List exprs = project.getProjects(); - // must be slot or Alias(slot) - return exprs.stream().allMatch(expr -> { - if (expr instanceof Slot) { - return true; - } - if (expr instanceof Alias) { - return ((Alias) expr).child() instanceof Slot; - } - return false; - }); + static boolean isOneSlotProject(LogicalProject> project) { + return project.getProjects().stream().allMatch(expr -> expr.getInputSlotExprIds().size() == 1); + } + + static boolean isAllSlotProject(LogicalProject> project) { + return project.getProjects().stream().allMatch(expr -> expr instanceof Slot); } static Map> splitProjection(List projects, Plan splitChild) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java index c18d4f2383..b7d62584bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java @@ -23,13 +23,13 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.Utils; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -57,6 +57,7 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory { .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet())) + .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) .then(topJoin -> { /* ********** init ********** */ List projects = topJoin.left().getProjects(); @@ -81,26 +82,26 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory { return null; } - // topJoin condition -> newBottomJoin condition, bottomJoin condition -> newTopJoin condition - JoinReorderHelper helper = new JoinReorderHelper(bottomJoin.getHashJoinConjuncts(), - bottomJoin.getOtherJoinConjuncts(), topJoin.getHashJoinConjuncts(), - topJoin.getOtherJoinConjuncts(), projects, aProjects, bProjects); - // Add all slots used by OnCondition when projects not empty. - helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects), - Collections.EMPTY_SET); + Set aExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(a, aProjects); + Map> abOnUsedSlots = Stream.concat( + bottomJoin.getHashJoinConjuncts().stream(), + bottomJoin.getHashJoinConjuncts().stream()) + .flatMap(onExpr -> onExpr.getInputSlots().stream()) + .collect(Collectors.partitioningBy( + slot -> aExprIdSet.contains(slot.getExprId()), Collectors.toSet())); + JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects); + JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects); bProjects.addAll(OuterJoinLAsscomProject.forceToNullable(c.getOutputSet())); /* ********** new Plan ********** */ - LogicalJoin newBottomJoin = topJoin.withConjunctsChildren(helper.newBottomHashConjuncts, - helper.newBottomOtherConjuncts, b, c); + LogicalJoin newBottomJoin = (LogicalJoin) topJoin.withChildren(b, c); newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext()); Plan left = JoinReorderUtils.projectOrSelf(aProjects, a); Plan right = JoinReorderUtils.projectOrSelf(bProjects, newBottomJoin); - LogicalJoin newTopJoin = bottomJoin.withConjunctsChildren(helper.newTopHashConjuncts, - helper.newTopOtherConjuncts, left, right); + LogicalJoin newTopJoin = (LogicalJoin) bottomJoin.withChildren(left, right); newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext()); OuterJoinAssoc.setReorderContext(newTopJoin, newBottomJoin); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java index 2562e6045a..8756df12ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java @@ -31,7 +31,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.Utils; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -60,7 +59,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { Pair.of(join.left().child().getJoinType(), join.getJoinType()))) .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) - .when(join -> JoinReorderUtils.checkProject(join.left())) + .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) .then(topJoin -> { /* ********** init ********** */ List projects = topJoin.left().getProjects(); @@ -85,20 +84,21 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { return null; } - // topJoin condition -> newBottomJoin condition, bottomJoin condition -> newTopJoin condition - JoinReorderHelper helper = new JoinReorderHelper(bottomJoin.getHashJoinConjuncts(), - bottomJoin.getOtherJoinConjuncts(), topJoin.getHashJoinConjuncts(), - topJoin.getOtherJoinConjuncts(), projects, aProjects, bProjects); - // Add all slots used by OnCondition when projects not empty. - helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects), - Collections.EMPTY_SET); + Set aExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(a, aProjects); + Map> abOnUsedSlots = Stream.concat( + bottomJoin.getHashJoinConjuncts().stream(), + bottomJoin.getHashJoinConjuncts().stream()) + .flatMap(onExpr -> onExpr.getInputSlots().stream()) + .collect(Collectors.partitioningBy( + slot -> aExprIdSet.contains(slot.getExprId()), Collectors.toSet())); + JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects); + JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects); aProjects.addAll(forceToNullable(c.getOutputSet())); /* ********** new Plan ********** */ - LogicalJoin newBottomJoin = topJoin.withConjunctsChildren(helper.newBottomHashConjuncts, - helper.newBottomOtherConjuncts, a, c); + LogicalJoin newBottomJoin = (LogicalJoin) topJoin.withChildren(a, c); newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); @@ -106,8 +106,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { Plan left = JoinReorderUtils.projectOrSelf(aProjects, newBottomJoin); Plan right = JoinReorderUtils.projectOrSelf(bProjects, b); - LogicalJoin newTopJoin = bottomJoin.withConjunctsChildren(helper.newTopHashConjuncts, - helper.newTopOtherConjuncts, left, right); + LogicalJoin newTopJoin = (LogicalJoin) bottomJoin.withChildren(left, right); newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java new file mode 100644 index 0000000000..57ea9df15d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * rule for pushdown project through left-semi/anti join + */ +public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory { + public static final PushdownProjectThroughSemiJoin INSTANCE = new PushdownProjectThroughSemiJoin(); + + /* + * Project Join + * | ──► / \ + * Join Project B + * / \ | + * A B A + */ + @Override + public Rule build() { + return logicalProject(logicalJoin()) + .when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin()) + .when(JoinReorderUtils::isOneSlotProject) + // Just pushdown project with non-column expr like (t.id + 1) + .whenNot(JoinReorderUtils::isAllSlotProject) + .whenNot(project -> project.child().hasJoinHint()) + .then(project -> { + LogicalJoin join = project.child(); + Set aOutputExprIdSet = join.left().getOutputSet(); + Set conditionLeftSlots = join.getConditionSlot().stream() + .filter(aOutputExprIdSet::contains) + .collect(Collectors.toSet()); + + List newProject = new ArrayList<>(project.getProjects()); + Set projectUsedSlots = project.getProjects().stream() + .map(NamedExpression::toSlot).collect(Collectors.toSet()); + conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)) + .forEach(newProject::add); + Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left()); + Plan newJoin = join.withChildren(newLeft, join.right()); + return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); + }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN); + } + + List sort(List projects, Plan sortPlan) { + List orderExprIds = sortPlan.getOutput().stream().map(Slot::getExprId).collect(Collectors.toList()); + // map { project input slot expr id -> project output expr } + Map map = projects.stream() + .collect(Collectors.toMap(expr -> expr.getInputSlots().iterator().next().getExprId(), expr -> expr)); + List newProjects = new ArrayList<>(); + for (ExprId exprId : orderExprIds) { + if (map.containsKey(exprId)) { + newProjects.add(map.get(exprId)); + } + } + return newProjects; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java index aefba5aace..374fe5758b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java @@ -67,7 +67,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto || topJoin.left().child().getJoinType().isRightOuterJoin()))) .whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin()) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) - .when(join -> JoinReorderUtils.checkProject(join.left())) + .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) .then(topSemiJoin -> { LogicalProject> project = topSemiJoin.left(); LogicalJoin bottomJoin = project.child(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java index 734108c4db..85ac3cc27b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java @@ -54,7 +54,7 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory .when(this::typeChecker) .when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) - .when(join -> JoinReorderUtils.checkProject(join.left())) + .when(join -> JoinReorderUtils.isAllSlotProject(join.left())) .then(topSemi -> { LogicalJoin bottomSemi = topSemi.left().child(); LogicalProject abProject = topSemi.left(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 9847c0e316..3214d71fb0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -35,11 +35,14 @@ import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; /** * Logical join plan. @@ -115,6 +118,12 @@ public class LogicalJoin getConditionSlot() { + return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream()) + .flatMap(expr -> expr.getInputSlots().stream()) + .collect(ImmutableSet.toImmutableSet()); + } + public Optional getOnClauseCondition() { return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProjectTest.java index 2111253168..96ede80a3b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProjectTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.rewrite.logical.PushdownAliasThroughJoin; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.EqualTo; @@ -38,6 +39,7 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.util.List; @@ -97,49 +99,27 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .printlnTree() + .applyTopDown(new PushdownAliasThroughJoin()) .applyExploration(InnerJoinLAsscomProject.INSTANCE.build()) + .printlnExploration() .matchesExploration( logicalJoin( logicalProject( logicalJoin( - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), + logicalProject(logicalOlapScan().when( + scan -> scan.getTable().getName().equals("t1"))), logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) ) ).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name logicalProject( - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) + logicalProject( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))) ).when(project -> project.getProjects().size() == 1) ) ); } - @Test - void testAliasTopMultiHashJoin() { - LogicalPlan plan = new LogicalPlanBuilder(scan1) - .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id=t2.id - .alias(ImmutableList.of(0, 2), ImmutableList.of("t1.id", "t2.id")) - // t1.id=t3.id t2.id = t3.id - .join(scan3, JoinType.INNER_JOIN, ImmutableList.of(Pair.of(0, 0), Pair.of(1, 0))) - .build(); - - PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyExploration(InnerJoinLAsscomProject.INSTANCE.build()) - .printlnOrigin() - .matchesExploration( - logicalJoin( - logicalProject( - logicalJoin( - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) - ).when(join -> join.getHashJoinConjuncts().size() == 1) - ).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name - logicalProject( - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) - ).when(project -> project.getProjects().size() == 1) - ).when(join -> join.getHashJoinConjuncts().size() == 2) - ); - } - @Test public void testHashAndOther() { // Alias (scan1 join scan2 on scan1.id=scan2.id and scan1.name>scan2.name); @@ -164,16 +144,13 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) .printlnTree() + .applyTopDown(new PushdownAliasThroughJoin()) .applyExploration(InnerJoinLAsscomProject.INSTANCE.build()) .printlnExploration() .matchesExploration( innerLogicalJoin( - logicalProject( - innerLogicalJoin().when( - join -> Objects.equals(join.getHashJoinConjuncts().toString(), - "[(id#0 = id#8)]") - && Objects.equals(join.getOtherJoinConjuncts().toString(), - "[(name#1 > name#9)]"))), + innerLogicalJoin().when(join -> join.getHashJoinConjuncts().size() == 1 + && join.getOtherJoinConjuncts().size() == 1), group() ).when(join -> Objects.equals(join.getHashJoinConjuncts().toString(), "[(t2.id#6 = id#8), (t1.id#4 = t2.id#6)]") @@ -201,7 +178,9 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported { * */ @Test + @Disabled public void testComplexConjuncts() { + // TODO: move to sql-test // Alias (scan1 join scan2 on scan1.id=scan2.id and scan1.name>scan2.name); List bottomHashJoinConjunct = ImmutableList.of( new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0))); @@ -262,7 +241,9 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported { * */ @Test + @Disabled public void testComplexConjunctsWithSubString() { + // TODO: move to sql-test // Alias (scan1 join scan2 on scan1.id=scan2.id and scan1.name>scan2.name); List bottomHashJoinConjunct = ImmutableList.of( new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0))); @@ -324,7 +305,9 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported { * */ @Test + @Disabled public void testComplexConjunctsAndAlias() { + // TODO: move to sql-test // Alias (scan1 join scan2 on scan1.id=scan2.id and scan1.name>scan2.name); List bottomHashJoinConjunct = ImmutableList.of( new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0))); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java index 90b8266f25..c94dcb12ca 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.rewrite.logical.PushdownAliasThroughJoin; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; @@ -76,21 +77,22 @@ class OuterJoinLAsscomProjectTest implements MemoPatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .printlnOrigin() + .applyTopDown(new PushdownAliasThroughJoin()) + .printlnTree() .applyExploration(OuterJoinLAsscomProject.INSTANCE.build()) .printlnExploration() .matchesExploration( - logicalJoin( - logicalProject( - logicalJoin( - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) - ) - ).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name - logicalProject( - logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")) - ).when(project -> project.getProjects().size() == 1) - ) + logicalJoin( + logicalProject( + logicalJoin( + logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1"))), + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3")) + ) + ).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name + logicalProject( + logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))) + ).when(project -> project.getProjects().size() == 1) + ) ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java new file mode 100644 index 0000000000..0a3b7a04ba --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +import java.util.List; + +class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + + @Test + public void pushdownProject() { + // project (t1.id + 1) as alias, t1.name + List projectExprs = ImmutableList.of( + new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"), + scan1.getOutput().get(1) + ); + // complex projection contain ti.id, which isn't in Join Condition + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 1)) + .projectExprs(projectExprs) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + leftSemiLogicalJoin( + logicalProject( + logicalOlapScan() + ).when(project -> project.getProjects().size() == 2), + logicalOlapScan() + ) + ); + } + + @Test + public void pushdownProjectInCondition() { + // project (t1.id + 1) as alias, t1.name + List projectExprs = ImmutableList.of( + new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"), + scan1.getOutput().get(1) + ); + // complex projection contain ti.id, which is in Join Condition + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .projectExprs(projectExprs) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build()) + .printlnOrigin() + .printlnExploration() + .matchesExploration( + logicalProject( + leftSemiLogicalJoin( + logicalProject( + logicalOlapScan() + ).when(project -> project.getProjects().size() == 3), + logicalOlapScan() + ) + ).when(project -> project.getProjects().size() == 2) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateDedupJoinConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateDedupJoinConditionTest.java index 4011e79960..2235821f16 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateDedupJoinConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateDedupJoinConditionTest.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.common.Pair; import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; @@ -30,10 +31,13 @@ import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; class EliminateDedupJoinConditionTest implements MemoPatternMatchSupported { + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + @Test void testEliminate() { - LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1) - .join(PlanConstructor.scan2, JoinType.INNER_JOIN, ImmutableList.of(Pair.of(0, 0), Pair.of(0, 0))) + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, ImmutableList.of(Pair.of(0, 0), Pair.of(0, 0))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan)