[feature](Nereids): pushdown complex project through left semi/anti Join. (#17186)

This commit is contained in:
jakevin
2023-03-02 21:41:08 +08:00
committed by GitHub
parent a1399043fe
commit 93d2d461b4
14 changed files with 286 additions and 203 deletions

View File

@ -229,6 +229,8 @@ public enum RuleType {
LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
@ -267,7 +269,6 @@ public enum RuleType {
LOGICAL_WINDOW_TO_PHYSICAL_WINDOW_RULE(RuleTypeClass.IMPLEMENTATION),
IMPLEMENTATION_SENTINEL(RuleTypeClass.IMPLEMENTATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
// sentinel, use to count rules
SENTINEL(RuleTypeClass.SENTINEL),
;

View File

@ -36,6 +36,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Rule for changing an inner join via LAsscom (associative and commutative).
@ -57,7 +58,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
return innerLogicalJoin(logicalProject(innerLogicalJoin()), group())
.when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.when(join -> JoinReorderUtils.checkProject(join.left()))
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topJoin -> {
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();
@ -93,18 +94,22 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
List<Expression> newTopOtherConjuncts = splitOtherConjuncts.get(true);
List<Expression> newBottomOtherConjuncts = splitOtherConjuncts.get(false);
JoinReorderHelper helper = new JoinReorderHelper(newTopHashConjuncts, newTopOtherConjuncts,
newBottomHashConjuncts, newBottomOtherConjuncts, projects, aProjects, bProjects);
// Add all slots used by OnCondition when projects not empty.
helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects),
c.getOutputExprIdSet());
Set<ExprId> aExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(a, aProjects);
// Gather every slot referenced by the bottom join's ON clause -- both the hash (equi)
// conjuncts and the other (non-equi) conjuncts -- and partition them by whether they
// belong to A's side (key true) or not (key false). The other conjuncts must be included:
// they are carried along with the bottom join's condition, so the slots they read have to
// be kept by the A/B projections that are extended from this map right below.
Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat(
        bottomJoin.getHashJoinConjuncts().stream(),
        bottomJoin.getOtherJoinConjuncts().stream())
        .flatMap(onExpr -> onExpr.getInputSlots().stream())
        .collect(Collectors.partitioningBy(
                slot -> aExprIdSet.contains(slot.getExprId()), Collectors.toSet()));
JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects);
JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects);
aProjects.addAll(cOutputSet);
/* ********** new Plan ********** */
LogicalJoin<Plan, Plan> newBottomJoin = topJoin.withConjunctsChildren(helper.newBottomHashConjuncts,
helper.newBottomOtherConjuncts, a, c);
LogicalJoin<Plan, Plan> newBottomJoin = topJoin.withConjunctsChildren(newBottomHashConjuncts,
newBottomOtherConjuncts, a, c);
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
newBottomJoin.getJoinReorderContext().setHasCommute(false);
@ -112,8 +117,8 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
Plan left = JoinReorderUtils.projectOrSelf(aProjects, newBottomJoin);
Plan right = JoinReorderUtils.projectOrSelf(bProjects, b);
LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(helper.newTopHashConjuncts,
helper.newTopOtherConjuncts, left, right);
LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(newTopHashConjuncts,
newTopOtherConjuncts, left, right);
newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
newTopJoin.getJoinReorderContext().setHasLAsscom(true);

View File

@ -1,99 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Helper used when a left-deep three-way join (original plan tree is {@code (a b) c})
 * is rewritten into new joins: it holds the conjunct lists of the new top/bottom joins
 * together with the old projection and the projections of the new left/right children.
 */
public class JoinReorderHelper {
    public List<Expression> newTopHashConjuncts;
    public List<Expression> newTopOtherConjuncts;
    public List<Expression> newBottomHashConjuncts;
    public List<Expression> newBottomOtherConjuncts;

    public List<NamedExpression> oldProjects;
    public List<NamedExpression> newLeftProjects;
    public List<NamedExpression> newRightProjects;

    /**
     * Constructor. Stores all arguments, then rewrites the four conjunct lists
     * according to the input/output slot mapping of {@code oldProjects}.
     */
    public JoinReorderHelper(List<Expression> newTopHashConjuncts, List<Expression> newTopOtherConjuncts,
            List<Expression> newBottomHashConjuncts, List<Expression> newBottomOtherConjuncts,
            List<NamedExpression> oldProjects, List<NamedExpression> newLeftProjects,
            List<NamedExpression> newRightProjects) {
        this.newTopHashConjuncts = newTopHashConjuncts;
        this.newTopOtherConjuncts = newTopOtherConjuncts;
        this.newBottomHashConjuncts = newBottomHashConjuncts;
        this.newBottomOtherConjuncts = newBottomOtherConjuncts;
        this.oldProjects = oldProjects;
        this.newLeftProjects = newLeftProjects;
        this.newRightProjects = newRightProjects;

        replaceConjuncts(oldProjects);
    }

    /**
     * Builds the input-slot <-> output-slot mappings of the given projection (every project
     * expression must read exactly one slot) and applies them to the conjunct lists:
     * bottom conjuncts are mapped output -> input, top conjuncts input -> output.
     */
    private void replaceConjuncts(List<NamedExpression> projects) {
        Map<ExprId, Slot> inputToOutput = new HashMap<>();
        Map<ExprId, Slot> outputToInput = new HashMap<>();
        for (NamedExpression project : projects) {
            Slot produced = project.toSlot();
            Set<Slot> consumed = project.getInputSlots();
            Preconditions.checkState(consumed.size() == 1);
            Slot source = consumed.iterator().next();
            inputToOutput.put(source.getExprId(), produced);
            outputToInput.put(produced.getExprId(), source);
        }

        newBottomHashConjuncts = JoinReorderUtils.replaceJoinConjuncts(newBottomHashConjuncts, outputToInput);
        newTopHashConjuncts = JoinReorderUtils.replaceJoinConjuncts(newTopHashConjuncts, inputToOutput);
        newBottomOtherConjuncts = JoinReorderUtils.replaceJoinConjuncts(newBottomOtherConjuncts, outputToInput);
        newTopOtherConjuncts = JoinReorderUtils.replaceJoinConjuncts(newTopOtherConjuncts, inputToOutput);
    }

    /**
     * Add all slots used by the new top join's ON condition to the left/right projections.
     *
     * @param splitIds expr ids that decide the side: a slot whose id is contained goes to
     *                 {@code newLeftProjects}, every other slot goes to {@code newRightProjects}.
     * @param cOutputExprIdSet expr ids to ignore; slots with these ids are filtered out first.
     */
    public void addSlotsUsedByOn(Set<ExprId> splitIds, Set<ExprId> cOutputExprIdSet) {
        Stream<Slot> onSlotsWithoutC = Stream.concat(newTopHashConjuncts.stream(), newTopOtherConjuncts.stream())
                .flatMap(conjunct -> conjunct.getInputSlots().stream())
                .filter(slot -> !cOutputExprIdSet.contains(slot.getExprId()));
        Map<Boolean, Set<Slot>> slotsBySide = onSlotsWithoutC.collect(
                Collectors.partitioningBy(slot -> splitIds.contains(slot.getExprId()), Collectors.toSet()));

        JoinReorderUtils.addSlotsUsedByOn(slotsBySide.get(true), newLeftProjects);
        JoinReorderUtils.addSlotsUsedByOn(slotsBySide.get(false), newRightProjects);
    }
}

View File

@ -17,7 +17,6 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@ -40,23 +39,16 @@ import java.util.stream.Stream;
*/
class JoinReorderUtils {
/**
* check project inside Join to prevent matching some pattern.
* just allow projection is slot or Alias(slot) to prevent reorder when:
* - output of project function is in condition, A join (project [abs(B.id), ..] B join C on ..) on abs(B.id)=A.id.
* - hyper edge in projection. project A.id + B.id A join B on .. (this project will prevent join reorder).
* check project Expression Input Slot just contains one slot, like:
* - one SlotReference like a.id
* - Input Slot size == 1, like abs(a.id) + 1
*/
static boolean checkProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
List<NamedExpression> exprs = project.getProjects();
// must be slot or Alias(slot)
return exprs.stream().allMatch(expr -> {
if (expr instanceof Slot) {
return true;
}
if (expr instanceof Alias) {
return ((Alias) expr).child() instanceof Slot;
}
return false;
});
static boolean isOneSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
    // True only when each project expression reads exactly one distinct input slot id.
    for (NamedExpression expr : project.getProjects()) {
        if (expr.getInputSlotExprIds().size() != 1) {
            return false;
        }
    }
    return true;
}
static boolean isAllSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
    // True only when every project expression is itself a Slot (no alias, no computation).
    for (NamedExpression expr : project.getProjects()) {
        if (!(expr instanceof Slot)) {
            return false;
        }
    }
    return true;
}
static Map<Boolean, List<NamedExpression>> splitProjection(List<NamedExpression> projects, Plan splitChild) {

View File

@ -23,13 +23,13 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.Utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -57,6 +57,7 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet()))
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topJoin -> {
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();
@ -81,26 +82,26 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
return null;
}
// topJoin condition -> newBottomJoin condition, bottomJoin condition -> newTopJoin condition
JoinReorderHelper helper = new JoinReorderHelper(bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinConjuncts(), topJoin.getHashJoinConjuncts(),
topJoin.getOtherJoinConjuncts(), projects, aProjects, bProjects);
// Add all slots used by OnCondition when projects not empty.
helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects),
Collections.EMPTY_SET);
Set<ExprId> aExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(a, aProjects);
// Gather every slot referenced by the bottom join's ON clause -- both the hash (equi)
// conjuncts and the other (non-equi) conjuncts -- and partition them by whether they
// belong to A's side (key true) or not (key false). The new top join below is built with
// bottomJoin.withChildren(...), i.e. it keeps all of the bottom join's conditions, so the
// slots of the other conjuncts must survive the A/B projections as well.
Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat(
        bottomJoin.getHashJoinConjuncts().stream(),
        bottomJoin.getOtherJoinConjuncts().stream())
        .flatMap(onExpr -> onExpr.getInputSlots().stream())
        .collect(Collectors.partitioningBy(
                slot -> aExprIdSet.contains(slot.getExprId()), Collectors.toSet()));
JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects);
JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects);
bProjects.addAll(OuterJoinLAsscomProject.forceToNullable(c.getOutputSet()));
/* ********** new Plan ********** */
LogicalJoin<Plan, Plan> newBottomJoin = topJoin.withConjunctsChildren(helper.newBottomHashConjuncts,
helper.newBottomOtherConjuncts, b, c);
LogicalJoin newBottomJoin = (LogicalJoin) topJoin.withChildren(b, c);
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
Plan left = JoinReorderUtils.projectOrSelf(aProjects, a);
Plan right = JoinReorderUtils.projectOrSelf(bProjects, newBottomJoin);
LogicalJoin<Plan, Plan> newTopJoin = bottomJoin.withConjunctsChildren(helper.newTopHashConjuncts,
helper.newTopOtherConjuncts, left, right);
LogicalJoin newTopJoin = (LogicalJoin) bottomJoin.withChildren(left, right);
newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
OuterJoinAssoc.setReorderContext(newTopJoin, newBottomJoin);

View File

@ -31,7 +31,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.Utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -60,7 +59,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
Pair.of(join.left().child().getJoinType(), join.getJoinType())))
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.when(join -> JoinReorderUtils.checkProject(join.left()))
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topJoin -> {
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();
@ -85,20 +84,21 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
return null;
}
// topJoin condition -> newBottomJoin condition, bottomJoin condition -> newTopJoin condition
JoinReorderHelper helper = new JoinReorderHelper(bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinConjuncts(), topJoin.getHashJoinConjuncts(),
topJoin.getOtherJoinConjuncts(), projects, aProjects, bProjects);
// Add all slots used by OnCondition when projects not empty.
helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects),
Collections.EMPTY_SET);
Set<ExprId> aExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(a, aProjects);
// Gather every slot referenced by the bottom join's ON clause -- both the hash (equi)
// conjuncts and the other (non-equi) conjuncts -- and partition them by whether they
// belong to A's side (key true) or not (key false). The new top join below is built with
// bottomJoin.withChildren(...), i.e. it keeps all of the bottom join's conditions, so the
// slots of the other conjuncts must survive the A/B projections as well.
Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat(
        bottomJoin.getHashJoinConjuncts().stream(),
        bottomJoin.getOtherJoinConjuncts().stream())
        .flatMap(onExpr -> onExpr.getInputSlots().stream())
        .collect(Collectors.partitioningBy(
                slot -> aExprIdSet.contains(slot.getExprId()), Collectors.toSet()));
JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects);
JoinReorderUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects);
aProjects.addAll(forceToNullable(c.getOutputSet()));
/* ********** new Plan ********** */
LogicalJoin newBottomJoin = topJoin.withConjunctsChildren(helper.newBottomHashConjuncts,
helper.newBottomOtherConjuncts, a, c);
LogicalJoin newBottomJoin = (LogicalJoin) topJoin.withChildren(a, c);
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
newBottomJoin.getJoinReorderContext().setHasCommute(false);
@ -106,8 +106,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
Plan left = JoinReorderUtils.projectOrSelf(aProjects, newBottomJoin);
Plan right = JoinReorderUtils.projectOrSelf(bProjects, b);
LogicalJoin newTopJoin = bottomJoin.withConjunctsChildren(helper.newTopHashConjuncts,
helper.newTopOtherConjuncts, left, right);
LogicalJoin newTopJoin = (LogicalJoin) bottomJoin.withChildren(left, right);
newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
newTopJoin.getJoinReorderContext().setHasLAsscom(true);

View File

@ -0,0 +1,88 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Exploration rule that pushes a project down through a left-semi/anti join,
 * placing it on the join's left child.
 */
public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory {
    public static final PushdownProjectThroughSemiJoin INSTANCE = new PushdownProjectThroughSemiJoin();

    /*
     *    Project                   Join
     *      |            ──►       /    \
     *     Join                Project   B
     *    /    \                  |
     *   A      B                 A
     */
    @Override
    public Rule build() {
        return logicalProject(logicalJoin())
                .when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
                .when(JoinReorderUtils::isOneSlotProject)
                // Only push down projects holding a non-column expression such as (t.id + 1).
                .whenNot(JoinReorderUtils::isAllSlotProject)
                .whenNot(project -> project.child().hasJoinHint())
                .then(project -> {
                    LogicalJoin<GroupPlan, GroupPlan> semiJoin = project.child();

                    // Slots of the join condition that are produced by the left child.
                    Set<Slot> leftOutput = semiJoin.left().getOutputSet();
                    Set<Slot> leftConditionSlots = semiJoin.getConditionSlot().stream()
                            .filter(leftOutput::contains)
                            .collect(Collectors.toSet());

                    // The pushed-down project keeps the original expressions and additionally
                    // outputs each left-side condition slot that it does not already produce.
                    List<NamedExpression> pushedProjects = new ArrayList<>(project.getProjects());
                    Set<Slot> projectOutput = project.getProjects().stream()
                            .map(NamedExpression::toSlot).collect(Collectors.toSet());
                    for (Slot conditionSlot : leftConditionSlots) {
                        if (!projectOutput.contains(conditionSlot)) {
                            pushedProjects.add(conditionSlot);
                        }
                    }

                    Plan projectedLeft = JoinReorderUtils.projectOrSelf(pushedProjects, semiJoin.left());
                    Plan rebuiltJoin = semiJoin.withChildren(projectedLeft, semiJoin.right());
                    // Re-project on top to the original project's output.
                    return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), rebuiltJoin);
                }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
    }

    /**
     * Returns the projects ordered like the output of {@code sortPlan}. Each project is keyed
     * by the expr id of the first slot returned by its {@code getInputSlots()}; projects whose
     * key does not appear in the output of {@code sortPlan} are left out.
     */
    List<NamedExpression> sort(List<NamedExpression> projects, Plan sortPlan) {
        List<ExprId> outputOrder = sortPlan.getOutput().stream().map(Slot::getExprId).collect(Collectors.toList());
        // map { project input slot expr id -> project output expr }
        Map<ExprId, NamedExpression> projectByInputId = projects.stream()
                .collect(Collectors.toMap(expr -> expr.getInputSlots().iterator().next().getExprId(), expr -> expr));

        List<NamedExpression> ordered = new ArrayList<>();
        for (ExprId id : outputOrder) {
            NamedExpression matched = projectByInputId.get(id);
            if (matched != null) {
                ordered.add(matched);
            }
        }
        return ordered;
    }
}

View File

@ -67,7 +67,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|| topJoin.left().child().getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.when(join -> JoinReorderUtils.checkProject(join.left()))
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topSemiJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = project.child();

View File

@ -54,7 +54,7 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory
.when(this::typeChecker)
.when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.when(join -> JoinReorderUtils.checkProject(join.left()))
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
.then(topSemi -> {
LogicalJoin<GroupPlan, GroupPlan> bottomSemi = topSemi.left().child();
LogicalProject abProject = topSemi.left();

View File

@ -35,11 +35,14 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
/**
* Logical join plan.
@ -115,6 +118,12 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
return hashJoinConjuncts;
}
/**
 * Returns the input slots of all join conjuncts, hash conjuncts first and then the other
 * conjuncts, as an immutable set.
 */
public Set<Slot> getConditionSlot() {
    ImmutableSet.Builder<Slot> conditionSlots = ImmutableSet.builder();
    for (Expression conjunct : hashJoinConjuncts) {
        conditionSlots.addAll(conjunct.getInputSlots());
    }
    for (Expression conjunct : otherJoinConjuncts) {
        conditionSlots.addAll(conjunct.getInputSlots());
    }
    return conditionSlots.build();
}
public Optional<Expression> getOnClauseCondition() {
return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts);
}

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownAliasThroughJoin;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
@ -38,6 +39,7 @@ import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.List;
@ -97,49 +99,27 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.printlnTree()
.applyTopDown(new PushdownAliasThroughJoin())
.applyExploration(InnerJoinLAsscomProject.INSTANCE.build())
.printlnExploration()
.matchesExploration(
logicalJoin(
logicalProject(
logicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
logicalProject(logicalOlapScan().when(
scan -> scan.getTable().getName().equals("t1"))),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
)
).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name
logicalProject(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
logicalProject(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")))
).when(project -> project.getProjects().size() == 1)
)
);
}
@Test
void testAliasTopMultiHashJoin() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id=t2.id
.alias(ImmutableList.of(0, 2), ImmutableList.of("t1.id", "t2.id"))
// t1.id=t3.id t2.id = t3.id
.join(scan3, JoinType.INNER_JOIN, ImmutableList.of(Pair.of(0, 0), Pair.of(1, 0)))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(InnerJoinLAsscomProject.INSTANCE.build())
.printlnOrigin()
.matchesExploration(
logicalJoin(
logicalProject(
logicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
).when(join -> join.getHashJoinConjuncts().size() == 1)
).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name
logicalProject(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
).when(project -> project.getProjects().size() == 1)
).when(join -> join.getHashJoinConjuncts().size() == 2)
);
}
@Test
public void testHashAndOther() {
// Alias (scan1 join scan2 on scan1.id=scan2.id and scan1.name>scan2.name);
@ -164,16 +144,13 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.printlnTree()
.applyTopDown(new PushdownAliasThroughJoin())
.applyExploration(InnerJoinLAsscomProject.INSTANCE.build())
.printlnExploration()
.matchesExploration(
innerLogicalJoin(
logicalProject(
innerLogicalJoin().when(
join -> Objects.equals(join.getHashJoinConjuncts().toString(),
"[(id#0 = id#8)]")
&& Objects.equals(join.getOtherJoinConjuncts().toString(),
"[(name#1 > name#9)]"))),
innerLogicalJoin().when(join -> join.getHashJoinConjuncts().size() == 1
&& join.getOtherJoinConjuncts().size() == 1),
group()
).when(join -> Objects.equals(join.getHashJoinConjuncts().toString(),
"[(t2.id#6 = id#8), (t1.id#4 = t2.id#6)]")
@ -201,7 +178,9 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported {
* </pre>
*/
@Test
@Disabled
public void testComplexConjuncts() {
// TODO: move to sql-test
// Alias (scan1 join scan2 on scan1.id=scan2.id and scan1.name>scan2.name);
List<Expression> bottomHashJoinConjunct = ImmutableList.of(
new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)));
@ -262,7 +241,9 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported {
* </pre>
*/
@Test
@Disabled
public void testComplexConjunctsWithSubString() {
// TODO: move to sql-test
// Alias (scan1 join scan2 on scan1.id=scan2.id and scan1.name>scan2.name);
List<Expression> bottomHashJoinConjunct = ImmutableList.of(
new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)));
@ -324,7 +305,9 @@ class InnerJoinLAsscomProjectTest implements MemoPatternMatchSupported {
* </pre>
*/
@Test
@Disabled
public void testComplexConjunctsAndAlias() {
// TODO: move to sql-test
// Alias (scan1 join scan2 on scan1.id=scan2.id and scan1.name>scan2.name);
List<Expression> bottomHashJoinConjunct = ImmutableList.of(
new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)));

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownAliasThroughJoin;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
@ -76,21 +77,22 @@ class OuterJoinLAsscomProjectTest implements MemoPatternMatchSupported {
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.printlnOrigin()
.applyTopDown(new PushdownAliasThroughJoin())
.printlnTree()
.applyExploration(OuterJoinLAsscomProject.INSTANCE.build())
.printlnExploration()
.matchesExploration(
logicalJoin(
logicalProject(
logicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
)
).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name
logicalProject(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
).when(project -> project.getProjects().size() == 1)
)
logicalJoin(
logicalProject(
logicalJoin(
logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1"))),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
)
).when(project -> project.getProjects().size() == 3), // t1.id Add t3.id, t3.name
logicalProject(
logicalProject(logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")))
).when(project -> project.getProjects().size() == 1)
)
);
}

View File

@ -0,0 +1,98 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
import java.util.List;
/**
 * Tests for {@link PushdownProjectThroughSemiJoin}: builds a project over a left semi join of
 * t1 and t2, applies the rule as an exploration rule and checks the shape of the explored plan.
 */
class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
/**
 * The join is built on column index 1 of both scans, while the complex project expression
 * reads column index 0 of scan1. The expected plan is the semi join at the root with a
 * two-expression project directly on its left child.
 */
@Test
public void pushdownProject() {
// project (t1.id + 1) as alias, t1.name
List<NamedExpression> projectExprs = ImmutableList.of(
new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"),
scan1.getOutput().get(1)
);
// the complex projection contains t1.id, which isn't in the join condition
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 1))
.projectExprs(projectExprs)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build())
.printlnOrigin()
.printlnExploration()
.matchesExploration(
leftSemiLogicalJoin(
logicalProject(
logicalOlapScan()
).when(project -> project.getProjects().size() == 2),
logicalOlapScan()
)
);
}
/**
 * The join is built on column index 0 of both scans, the same scan1 column the complex
 * project expression reads. The expected plan keeps a two-expression project on top of the
 * semi join, and the project pushed below the join has three expressions.
 */
@Test
public void pushdownProjectInCondition() {
// project (t1.id + 1) as alias, t1.name
List<NamedExpression> projectExprs = ImmutableList.of(
new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"),
scan1.getOutput().get(1)
);
// the complex projection contains t1.id, which is in the join condition
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
.projectExprs(projectExprs)
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build())
.printlnOrigin()
.printlnExploration()
.matchesExploration(
logicalProject(
leftSemiLogicalJoin(
logicalProject(
logicalOlapScan()
).when(project -> project.getProjects().size() == 3),
logicalOlapScan()
)
).when(project -> project.getProjects().size() == 2)
);
}
}

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
@ -30,10 +31,13 @@ import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
class EliminateDedupJoinConditionTest implements MemoPatternMatchSupported {
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
@Test
void testEliminate() {
LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1)
.join(PlanConstructor.scan2, JoinType.INNER_JOIN, ImmutableList.of(Pair.of(0, 0), Pair.of(0, 0)))
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, ImmutableList.of(Pair.of(0, 0), Pair.of(0, 0)))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)