[feature](Nereids): pushdown complex project through inner/outer Join. (#17365)
This commit is contained in:
@ -235,6 +235,7 @@ public enum RuleType {
|
||||
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
|
||||
PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
|
||||
|
||||
// implementation rules
|
||||
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
|
||||
|
||||
@ -38,15 +38,6 @@ import java.util.stream.Stream;
|
||||
* Common
|
||||
*/
|
||||
class JoinReorderUtils {
|
||||
/**
|
||||
* check project Expression Input Slot just contains one slot, like:
|
||||
* - one SlotReference like a.id
|
||||
* - Input Slot size == 1, like abs(a.id) + 1
|
||||
*/
|
||||
static boolean isOneSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
|
||||
return project.getProjects().stream().allMatch(expr -> expr.getInputSlotExprIds().size() == 1);
|
||||
}
|
||||
|
||||
static boolean isAllSlotProject(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) {
|
||||
return project.getProjects().stream().allMatch(expr -> expr instanceof Slot);
|
||||
}
|
||||
@ -78,6 +69,13 @@ class JoinReorderUtils {
|
||||
return new LogicalProject<>(projectExprs, plan);
|
||||
}
|
||||
|
||||
public static Plan projectOrSelfInOrder(List<NamedExpression> projectExprs, Plan plan) {
|
||||
if (projectExprs.isEmpty() || projectExprs.equals(plan.getOutput())) {
|
||||
return plan;
|
||||
}
|
||||
return new LogicalProject<>(projectExprs, plan);
|
||||
}
|
||||
|
||||
/**
|
||||
* replace JoinConjuncts by using slots map.
|
||||
*/
|
||||
@ -111,4 +109,11 @@ class JoinReorderUtils {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public static Set<Slot> joinChildConditionSlots(LogicalJoin<? extends Plan, ? extends Plan> join, boolean left) {
|
||||
Set<Slot> childSlots = left ? join.left().getOutputSet() : join.right().getOutputSet();
|
||||
return join.getConditionSlot().stream()
|
||||
.filter(childSlots::contains)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,104 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* rule for pushdown project through inner/outer join
|
||||
*/
|
||||
public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory {
|
||||
public static final PushdownProjectThroughInnerJoin INSTANCE = new PushdownProjectThroughInnerJoin();
|
||||
|
||||
/*
|
||||
* Project Join
|
||||
* | ──► / \
|
||||
* Join Project Project
|
||||
* / \ | |
|
||||
* A B A B
|
||||
*/
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalProject(logicalJoin())
|
||||
.when(project -> project.child().getJoinType().isInnerJoin())
|
||||
.whenNot(project -> project.child().hasJoinHint())
|
||||
.then(project -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
|
||||
Set<ExprId> aOutputExprIdSet = join.left().getOutputExprIdSet();
|
||||
Set<ExprId> bOutputExprIdSet = join.right().getOutputExprIdSet();
|
||||
|
||||
// reject hyper edge in Project.
|
||||
if (!project.getProjects().stream().allMatch(expr -> {
|
||||
Set<ExprId> inputSlotExprIds = expr.getInputSlotExprIds();
|
||||
return aOutputExprIdSet.containsAll(inputSlotExprIds)
|
||||
|| bOutputExprIdSet.containsAll(inputSlotExprIds);
|
||||
})) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(project.getProjects(),
|
||||
join.left());
|
||||
List<NamedExpression> aProjects = map.get(true);
|
||||
List<NamedExpression> bProjects = map.get(false);
|
||||
|
||||
boolean leftContains = aProjects.stream().anyMatch(e -> !(e instanceof Slot));
|
||||
boolean rightContains = bProjects.stream().anyMatch(e -> !(e instanceof Slot));
|
||||
// due to JoinCommute, we don't need to consider just right contains.
|
||||
if (!leftContains) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Builder<NamedExpression> newAProject = ImmutableList.<NamedExpression>builder().addAll(aProjects);
|
||||
Set<Slot> aConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, true);
|
||||
Set<Slot> aProjectSlots = aProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
|
||||
aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add);
|
||||
Plan newLeft = JoinReorderUtils.projectOrSelf(newAProject.build(), join.left());
|
||||
|
||||
if (!rightContains) {
|
||||
Plan newJoin = join.withChildren(newLeft, join.right());
|
||||
return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
|
||||
}
|
||||
|
||||
Builder<NamedExpression> newBProject = ImmutableList.<NamedExpression>builder().addAll(bProjects);
|
||||
Set<Slot> bConditionSlots = JoinReorderUtils.joinChildConditionSlots(join, false);
|
||||
Set<Slot> bProjectSlots = bProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet());
|
||||
bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add);
|
||||
Plan newRight = JoinReorderUtils.projectOrSelf(newBProject.build(), join.right());
|
||||
|
||||
Plan newJoin = join.withChildren(newLeft, newRight);
|
||||
return JoinReorderUtils.projectOrSelfInOrder(new ArrayList<>(project.getOutput()), newJoin);
|
||||
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN);
|
||||
}
|
||||
}
|
||||
@ -49,27 +49,23 @@ public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalProject(logicalJoin())
|
||||
.when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
|
||||
.when(JoinReorderUtils::isOneSlotProject)
|
||||
// Just pushdown project with non-column expr like (t.id + 1)
|
||||
.whenNot(JoinReorderUtils::isAllSlotProject)
|
||||
.whenNot(project -> project.child().hasJoinHint())
|
||||
.then(project -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
|
||||
Set<Slot> aOutputExprIdSet = join.left().getOutputSet();
|
||||
Set<Slot> conditionLeftSlots = join.getConditionSlot().stream()
|
||||
.filter(aOutputExprIdSet::contains)
|
||||
.collect(Collectors.toSet());
|
||||
.when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin())
|
||||
// Just pushdown project with non-column expr like (t.id + 1)
|
||||
.whenNot(JoinReorderUtils::isAllSlotProject)
|
||||
.whenNot(project -> project.child().hasJoinHint())
|
||||
.then(project -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
|
||||
Set<Slot> conditionLeftSlots = JoinReorderUtils.joinChildConditionSlots(join, true);
|
||||
|
||||
List<NamedExpression> newProject = new ArrayList<>(project.getProjects());
|
||||
Set<Slot> projectUsedSlots = project.getProjects().stream()
|
||||
.map(NamedExpression::toSlot).collect(Collectors.toSet());
|
||||
conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot))
|
||||
.forEach(newProject::add);
|
||||
Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left());
|
||||
Plan newJoin = join.withChildren(newLeft, join.right());
|
||||
return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
|
||||
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
|
||||
List<NamedExpression> newProject = new ArrayList<>(project.getProjects());
|
||||
Set<Slot> projectUsedSlots = project.getProjects().stream().map(NamedExpression::toSlot)
|
||||
.collect(Collectors.toSet());
|
||||
conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)).forEach(newProject::add);
|
||||
Plan newLeft = JoinReorderUtils.projectOrSelf(newProject, join.left());
|
||||
|
||||
Plan newJoin = join.withChildren(newLeft, join.right());
|
||||
return JoinReorderUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin);
|
||||
}).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN);
|
||||
}
|
||||
|
||||
List<NamedExpression> sort(List<NamedExpression> projects, Plan sortPlan) {
|
||||
|
||||
@ -0,0 +1,151 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported {
|
||||
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
|
||||
@Test
|
||||
public void pushBothSide() {
|
||||
// project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name
|
||||
List<NamedExpression> projectExprs = ImmutableList.of(
|
||||
new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"),
|
||||
scan1.getOutput().get(1),
|
||||
new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)), "alias"),
|
||||
scan2.getOutput().get(1)
|
||||
);
|
||||
// complex projection contain ti.id, which isn't in Join Condition
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(1, 1))
|
||||
.projectExprs(projectExprs)
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
|
||||
.printlnOrigin()
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalJoin(
|
||||
logicalProject().when(project -> project.getProjects().size() == 2),
|
||||
logicalProject().when(project -> project.getProjects().size() == 2)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pushdownProjectInCondition() {
|
||||
// project (t1.id + 1) as alias, t1.name, (t2.id + 1) as alias, t2.name
|
||||
List<NamedExpression> projectExprs = ImmutableList.of(
|
||||
new Alias(new Add(scan1.getOutput().get(0), Literal.of(1)), "alias"),
|
||||
scan1.getOutput().get(1),
|
||||
new Alias(new Add(scan2.getOutput().get(0), Literal.of(1)), "alias"),
|
||||
scan2.getOutput().get(1)
|
||||
);
|
||||
// complex projection contain ti.id, which is in Join Condition
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.projectExprs(projectExprs)
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
|
||||
.printlnOrigin()
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalProject().when(project -> project.getProjects().size() == 3),
|
||||
logicalProject().when(project -> project.getProjects().size() == 3)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void pushComplexProject() {
|
||||
// project (t1.id + t1.name) as complex1, (t2.id + t2.name) as complex2
|
||||
List<NamedExpression> projectExprs = ImmutableList.of(
|
||||
new Alias(new Add(scan1.getOutput().get(0), scan1.getOutput().get(1)), "complex1"),
|
||||
new Alias(new Add(scan2.getOutput().get(0), scan2.getOutput().get(1)), "complex2")
|
||||
);
|
||||
// complex projection contain ti.id, which is in Join Condition
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.projectExprs(projectExprs)
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
|
||||
.printlnOrigin()
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalProject()
|
||||
.when(project ->
|
||||
project.getProjects().get(0).toSql().equals("(id + name) AS `complex1`")
|
||||
&& project.getProjects().get(1).toSql().equals("id")),
|
||||
logicalProject()
|
||||
.when(project ->
|
||||
project.getProjects().get(0).toSql().equals("(id + name) AS `complex2`")
|
||||
&& project.getProjects().get(1).toSql().equals("id"))
|
||||
)
|
||||
).when(project -> project.getProjects().get(0).toSql().equals("complex1")
|
||||
&& project.getProjects().get(1).toSql().equals("complex2")
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void rejectHyperEdgeProject() {
|
||||
// project (t1.id + t2.id) as alias
|
||||
List<NamedExpression> projectExprs = ImmutableList.of(
|
||||
new Alias(new Add(scan1.getOutput().get(0), scan2.getOutput().get(0)), "alias")
|
||||
);
|
||||
// complex projection contain ti.id, which is in Join Condition
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.projectExprs(projectExprs)
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build())
|
||||
.checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()));
|
||||
}
|
||||
}
|
||||
@ -95,4 +95,31 @@ class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported {
|
||||
).when(project -> project.getProjects().size() == 2)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void pushComplexProject() {
|
||||
// project (t1.id + t1.name) as complex
|
||||
List<NamedExpression> projectExprs = ImmutableList.of(
|
||||
new Alias(new Add(scan1.getOutput().get(0), scan1.getOutput().get(1)), "complex"));
|
||||
// complex projection contain ti.id, which is in Join Condition
|
||||
LogicalPlan plan = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.projectExprs(projectExprs)
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build())
|
||||
.printlnOrigin()
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalProject(
|
||||
leftSemiLogicalJoin(
|
||||
logicalProject()
|
||||
.when(project -> project.getProjects().get(0).toSql().equals("(id + name) AS `complex`")
|
||||
&& project.getProjects().get(1).toSql().equals("id")),
|
||||
logicalOlapScan()
|
||||
)
|
||||
).when(project -> project.getProjects().get(0).toSql().equals("complex"))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user