[feature](Nereids): pushdown complex project through inner/outer Join. (#17365)

This commit is contained in:
jakevin
2023-03-08 12:00:56 +08:00
committed by GitHub
parent 778acb3c5b
commit 2b6133f4d0
6 changed files with 313 additions and 29 deletions

View File

@ -235,6 +235,7 @@ public enum RuleType {
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),

View File

@ -38,15 +38,6 @@ import java.util.stream.Stream;
* Common
*/
class JoinReorderUtils {
/**
 * Tells whether every projection of the given project reads exactly one input slot, e.g.
 * - a bare SlotReference such as a.id
 * - an expression built on a single slot such as abs(a.id) + 1
 *
 * @param project the project (over a join) whose projections are inspected
 * @return true when each projection has exactly one input slot ExprId; also true for an empty projection list
 */
static boolean isOneSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
    for (NamedExpression projection : project.getProjects()) {
        if (projection.getInputSlotExprIds().size() != 1) {
            return false;
        }
    }
    return true;
}
/**
 * Tells whether the given project consists of plain slots only, i.e. none of its projections is a
 * computed expression.
 *
 * @param project the project (over a join) whose projections are inspected
 * @return true when every projection is a {@code Slot}; also true for an empty projection list
 */
static boolean isAllSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
    return project.getProjects().stream().allMatch(Slot.class::isInstance);
}
@ -78,6 +69,13 @@ class JoinReorderUtils {
return new LogicalProject<>(projectExprs, plan);
}
/**
 * Wraps {@code plan} in a project unless the project would be redundant.
 * The requested expressions are compared with the plan output by list equality, so a plan that
 * produces the same expressions in a different order still gets a project on top.
 *
 * @param projectExprs the expressions the result should output
 * @param plan the plan to wrap
 * @return {@code plan} itself when {@code projectExprs} is empty or equals the plan output,
 *         otherwise a new LogicalProject of {@code projectExprs} over {@code plan}
 */
public static Plan projectOrSelfInOrder(List<NamedExpression> projectExprs, Plan plan) {
    if (!projectExprs.isEmpty() && !projectExprs.equals(plan.getOutput())) {
        return new LogicalProject<>(projectExprs, plan);
    }
    return plan;
}
/**
* replace JoinConjuncts by using slots map.
*/
@ -111,4 +109,11 @@ class JoinReorderUtils {
}
});
}
/**
 * Collects the join-condition slots that are produced by one child of the join.
 *
 * @param join the join whose condition slots are inspected
 * @param left true to select slots output by the left child, false for the right child
 * @return the slots of {@code join.getConditionSlot()} contained in the chosen child's output set
 */
public static Set<Slot> joinChildConditionSlots(LogicalJoin<? extends Plan, ? extends Plan> join, boolean left) {
    Plan chosenChild = left ? join.left() : join.right();
    Set<Slot> chosenChildOutput = chosenChild.getOutputSet();
    return join.getConditionSlot().stream()
            .filter(conditionSlot -> chosenChildOutput.contains(conditionSlot))
            .collect(Collectors.toSet());
}
}

View File

@ -0,0 +1,104 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* rule for pushdown project through inner/outer join
*/
public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory {
public static final PushdownProjectThroughInnerJoin INSTANCE = new PushdownProjectThroughInnerJoin();
/*
* Project Join
* | ──► / \
* Join Project Project
* / \ | |
* A B A B
*/
@Override
public Rule build() {
return logicalProject(logicalJoin())
.when(project -> project.child().getJoinType().isInnerJoin())
.whenNot(project -> project.child().hasJoinHint())
.then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
Set<ExprId> aOutputExprIdSet = join.left().getOutputExprIdSet();
Set<ExprId> bOutputExprIdSet = join.right().getOutputExprIdSet();
// reject hyper edge in Project.
if (!project.getProjects().stream().allMatch(expr -> {
Set<ExprId> inputSlotExprIds = expr.getInputSlotExprIds();
return aOutputExprIdSet.containsAll(inputSlotExprIds)
|| bOutputExprIdSet.containsAll(inputSlotExprIds);
})) {
return null;
}
Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(project.getProjects(),
join.left());
List<NamedExpression> aProjects = map.get(true);
List<NamedExpression> bProjects = map.get(false);
boolean leftContains = aProjects.stream().anyMatch(e -> !(e instanceof Slot));
boolean rightContains = bProjects.stream().anyMatch(e -> !(e instanceof Slot));
// due to JoinCommute, we don't need to consider just right contains.
if (!leftContains) {
return null;
}
Builder<NamedExpression> newAProject = ImmutableList.<NamedExpression>builder().addAll(aProjects);
Set<Slot> aConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, true);
Set<Slot> aProjectSlots = aProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add);
Plan newLeft = JoinReorderUtils.projectOrSelf(newAProject.build(), join.left());
if (!rightContains) {
Plan newJoin = join.withChildren(newLeft, join.right());
return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}
Builder<NamedExpression> newBProject = ImmutableList.<NamedExpression>builder().addAll(bProjects);
Set<Slot> bConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, false);
Set<Slot> bProjectSlots = bProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add);
Plan newRight = JoinReorderUtils.projectOrSelf(newBProject.build(), join.right());
Plan newJoin = join.withChildren(newLeft, newRight);
return JoinReorderUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin);
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN);
}
}

View File

@ -49,27 +49,23 @@ public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory {
// Builds the exploration rule registered as PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN: it matches a project
// over a left semi/anti join without join hint and rebuilds the join on top of a projected left child.
// NOTE(review): this hunk is rendered without +/- diff markers, so two variants of the rule chain
// appear interleaved below. The "old"/"new" labels on each run are a best guess based on the hunk
// header and the repeated statements — confirm against the repository before treating this as code.
@Override
public Rule build() {
return logicalProject(logicalJoin())
// NOTE(review): presumably removed lines (old variant, part 1): still guarded by isOneSlotProject
// and filters the left-side condition slots inline.
.when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
.when(JoinReorderUtils::isOneSlotProject)
// Just pushdown project with non-column expr like (t.id + 1)
.whenNot(JoinReorderUtils::isAllSlotProject)
.whenNot(project -> project.child().hasJoinHint())
.then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
Set<Slot> aOutputExprIdSet = join.left().getOutputSet();
Set<Slot> conditionLeftSlots = join.getConditionSlot().stream()
.filter(aOutputExprIdSet::contains)
.collect(Collectors.toSet());
// NOTE(review): presumably added lines (new variant, part 1): no isOneSlotProject guard, and the
// left-side condition slots come from JoinReorderUtils.joinChildConditionSlots(join, true).
.when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
// Just pushdown project with non-column expr like (t.id + 1)
.whenNot(JoinReorderUtils::isAllSlotProject)
.whenNot(project -> project.child().hasJoinHint())
.then(project -> {
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
Set<Slot> conditionLeftSlots = JoinReorderUtils.joinChildConditionSlots(join, true);
// NOTE(review): presumably removed lines (old variant, part 2).
// Start from the original projections and append every left-side condition slot they do not
// already output (presumably so the join condition still finds its inputs above the new project).
List<NamedExpression> newProject = new ArrayList<>(project.getProjects());
Set<Slot> projectUsedSlots = project.getProjects().stream()
.map(NamedExpression::toSlot).collect(Collectors.toSet());
conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot))
.forEach(newProject::add);
// The pushed-down project goes over the left child only; the right child is reused unchanged.
Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left());
Plan newJoin = join.withChildren(newLeft, join.right());
// Re-project to the original project's output on top of the new join (or return the join itself).
return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
// NOTE(review): presumably added lines (new variant, part 2): the same statements as the run
// above, only wrapped differently.
List<NamedExpression> newProject = new ArrayList<>(project.getProjects());
Set<Slot> projectUsedSlots = project.getProjects().stream().map(NamedExpression::toSlot)
.collect(Collectors.toSet());
conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)).forEach(newProject::add);
Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left());
Plan newJoin = join.withChildren(newLeft, join.right());
return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
}
List<NamedExpression> sort(List<NamedExpression> projects, Plan sortPlan) {

View File

@ -0,0 +1,151 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
/**
 * Exploration tests for {@link PushdownProjectThroughInnerJoin}: each test puts a project over an
 * inner join of t1 and t2, applies the rule, and checks the shape of the explored plan.
 */
class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported {
    private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
    private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

    /** Both sides carry a computed expression and the join condition does not use the id columns. */
    @Test
    public void pushBothSide() {
        // joined on column 1, so t1.id / t2.id read by the complex projections are not condition slots
        LogicalPlan plan = projectOverInnerJoin(1, 1, plusOneAndNameOfBothTables());

        exploreAndPrint(plan)
                .matchesExploration(
                        logicalJoin(
                                logicalProject().when(project -> project.getProjects().size() == 2),
                                logicalProject().when(project -> project.getProjects().size() == 2)
                        )
                );
    }

    /** Same projections, but the join condition uses the id columns the complex projections read. */
    @Test
    public void pushdownProjectInCondition() {
        // joined on column 0 (id): each pushed-down project is expected to hold one extra expression
        LogicalPlan plan = projectOverInnerJoin(0, 0, plusOneAndNameOfBothTables());

        exploreAndPrint(plan)
                .matchesExploration(
                        logicalProject(
                                logicalJoin(
                                        logicalProject().when(project -> project.getProjects().size() == 3),
                                        logicalProject().when(project -> project.getProjects().size() == 3)
                                )
                        )
                );
    }

    /** Projections that read two slots of the same table are pushed down as well. */
    @Test
    void pushComplexProject() {
        // project (t1.id + t1.name) as complex1, (t2.id + t2.name) as complex2
        Alias complex1 = new Alias(new Add(scan1.getOutput().get(0), scan1.getOutput().get(1)), "complex1");
        Alias complex2 = new Alias(new Add(scan2.getOutput().get(0), scan2.getOutput().get(1)), "complex2");
        List<NamedExpression> projectExprs = ImmutableList.of(complex1, complex2);

        // joined on column 0 (id), which both complex projections read
        LogicalPlan plan = projectOverInnerJoin(0, 0, projectExprs);

        exploreAndPrint(plan)
                .matchesExploration(
                        logicalProject(
                                logicalJoin(
                                        logicalProject()
                                                .when(project ->
                                                        project.getProjects().get(0).toSql()
                                                                .equals("(id + name) AS `complex1`")
                                                        && project.getProjects().get(1).toSql().equals("id")),
                                        logicalProject()
                                                .when(project ->
                                                        project.getProjects().get(0).toSql()
                                                                .equals("(id + name) AS `complex2`")
                                                        && project.getProjects().get(1).toSql().equals("id"))
                                )
                        ).when(project -> project.getProjects().get(0).toSql().equals("complex1")
                                && project.getProjects().get(1).toSql().equals("complex2")
                        )
                );
    }

    /** A projection mixing slots of both tables must leave the memo root with its single original plan. */
    @Test
    void rejectHyperEdgeProject() {
        // project (t1.id + t2.id) as alias
        List<NamedExpression> projectExprs = ImmutableList.of(
                new Alias(new Add(scan1.getOutput().get(0), scan2.getOutput().get(0)), "alias")
        );

        LogicalPlan plan = projectOverInnerJoin(0, 0, projectExprs);

        applyRule(plan)
                .checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()));
    }

    /** project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name */
    private List<NamedExpression> plusOneAndNameOfBothTables() {
        return ImmutableList.of(
                plusOne(scan1.getOutput().get(0)),
                scan1.getOutput().get(1),
                plusOne(scan2.getOutput().get(0)),
                scan2.getOutput().get(1)
        );
    }

    /** Builds {@code (slot + 1) AS alias}. */
    private static Alias plusOne(NamedExpression slot) {
        return new Alias(new Add(slot, Literal.of(1)), "alias");
    }

    /** Builds project(projectExprs) over scan1 INNER JOIN scan2, joined on the given output column indexes. */
    private LogicalPlan projectOverInnerJoin(int leftColumn, int rightColumn, List<NamedExpression> projectExprs) {
        return new LogicalPlanBuilder(scan1)
                .join(scan2, JoinType.INNER_JOIN, Pair.of(leftColumn, rightColumn))
                .projectExprs(projectExprs)
                .build();
    }

    /** Applies the rule under test to {@code plan} in a fresh connect context. */
    private PlanChecker applyRule(LogicalPlan plan) {
        return PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
                .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build());
    }

    /** Applies the rule and prints both the original and the explored plans before further checks. */
    private PlanChecker exploreAndPrint(LogicalPlan plan) {
        return applyRule(plan)
                .printlnOrigin()
                .printlnExploration();
    }
}

View File

@ -95,4 +95,31 @@ class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported {
).when(project -> project.getProjects().size() == 2)
);
}
/**
 * A projection reading two slots of the left table is pushed below a left semi join, together with
 * the join-condition slot it does not output itself.
 */
@Test
void pushComplexProject() {
    // project (t1.id + t1.name) as complex
    Alias complex = new Alias(new Add(scan1.getOutput().get(0), scan1.getOutput().get(1)), "complex");
    List<NamedExpression> projectExprs = ImmutableList.of(complex);

    // joined on column 0 (id), which the complex projection reads
    LogicalPlan plan = new LogicalPlanBuilder(scan1)
            .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
            .projectExprs(projectExprs)
            .build();

    PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
            .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build())
            .printlnOrigin()
            .printlnExploration()
            .matchesExploration(
                    logicalProject(
                            leftSemiLogicalJoin(
                                    // pushed-down project: the complex expression plus the id condition slot
                                    logicalProject().when(project ->
                                            project.getProjects().get(0).toSql().equals("(id + name) AS `complex`")
                                                    && project.getProjects().get(1).toSql().equals("id")),
                                    logicalOlapScan()
                            )
                    ).when(project -> project.getProjects().get(0).toSql().equals("complex"))
            );
}
}