[feature](Nereids): pushdown Alias through Join. (#17150)
This commit is contained in:
@ -53,6 +53,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeGenerates;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeLimits;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownAliasThroughJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownExpressionsInHashCondition;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughAggregation;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownFilterThroughJoin;
|
||||
@ -104,6 +105,7 @@ public class RuleSet {
|
||||
new PushdownFilterThroughRepeat(),
|
||||
new PushdownFilterThroughSetOperation(),
|
||||
new PushdownProjectThroughLimit(),
|
||||
new PushdownAliasThroughJoin(),
|
||||
new EliminateOuterJoin(),
|
||||
new MergeProjects(),
|
||||
new MergeFilters(),
|
||||
|
||||
@ -126,6 +126,7 @@ public enum RuleType {
|
||||
PUSHDOWN_FILTER_THROUGH_PROJECT(RuleTypeClass.REWRITE),
|
||||
PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT(RuleTypeClass.REWRITE),
|
||||
PUSHDOWN_PROJECT_THROUGH_LIMIT(RuleTypeClass.REWRITE),
|
||||
PUSHDOWN_ALIAS_THROUGH_JOIN(RuleTypeClass.REWRITE),
|
||||
PUSHDOWN_FILTER_THROUGH_SET_OPERATION(RuleTypeClass.REWRITE),
|
||||
// column prune rules,
|
||||
COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -0,0 +1,126 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.UnaryNode;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Pushdown Alias (inside must be Slot) through Join.
|
||||
*/
|
||||
public class PushdownAliasThroughJoin extends OneRewriteRuleFactory {
|
||||
private boolean isAllSlotOrAliasSlot(LogicalProject<? extends Plan> project) {
|
||||
return project.getProjects().stream().allMatch(expr -> {
|
||||
if (expr instanceof Slot) {
|
||||
return true;
|
||||
}
|
||||
if (expr instanceof Alias) {
|
||||
return ((Alias) expr).child() instanceof Slot;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalProject(logicalJoin())
|
||||
.when(this::isAllSlotOrAliasSlot)
|
||||
.then(project -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> join = project.child();
|
||||
// aliasMap { Slot -> Alias<Slot> }
|
||||
Map<Expression, NamedExpression> aliasMap = project.getProjects().stream()
|
||||
.filter(expr -> expr instanceof Alias && ((Alias) expr).child() instanceof Slot)
|
||||
.map(expr -> (Alias) expr).collect(Collectors.toMap(UnaryNode::child, expr -> expr));
|
||||
if (aliasMap.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
List<NamedExpression> newProjects = project.getProjects().stream().map(NamedExpression::toSlot)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<Slot> leftOutput = join.left().getOutput();
|
||||
List<NamedExpression> leftProjects = leftOutput.stream().map(slot -> {
|
||||
NamedExpression alias = aliasMap.get(slot);
|
||||
if (alias != null) {
|
||||
return alias;
|
||||
}
|
||||
return slot;
|
||||
}).collect(Collectors.toList());
|
||||
List<Slot> rightOutput = join.right().getOutput();
|
||||
List<NamedExpression> rightProjects = rightOutput.stream().map(slot -> {
|
||||
NamedExpression alias = aliasMap.get(slot);
|
||||
if (alias != null) {
|
||||
return alias;
|
||||
}
|
||||
return slot;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
Plan left;
|
||||
Plan right;
|
||||
if (leftOutput.equals(leftProjects)) {
|
||||
left = join.left();
|
||||
} else {
|
||||
left = new LogicalProject<>(leftProjects, join.left());
|
||||
}
|
||||
if (rightOutput.equals(rightProjects)) {
|
||||
right = join.right();
|
||||
} else {
|
||||
right = new LogicalProject<>(rightProjects, join.right());
|
||||
}
|
||||
|
||||
// If condition use alias slot, we should replace condition
|
||||
// project a.id as aid -- join a.id = b.id =>
|
||||
// join aid = b.id -- project a.id as aid
|
||||
Map<ExprId, Slot> replaceMap = aliasMap.entrySet().stream().collect(
|
||||
Collectors.toMap(entry -> ((Slot) entry.getKey()).getExprId(),
|
||||
entry -> entry.getValue().toSlot()));
|
||||
|
||||
List<Expression> newHash = replaceJoinConjuncts(join.getHashJoinConjuncts(), replaceMap);
|
||||
List<Expression> newOther = replaceJoinConjuncts(join.getOtherJoinConjuncts(), replaceMap);
|
||||
|
||||
Plan newJoin = join.withConjunctsChildren(newHash, newOther, left, right);
|
||||
return new LogicalProject<>(newProjects, newJoin);
|
||||
}).toRule(RuleType.PUSHDOWN_ALIAS_THROUGH_JOIN);
|
||||
}
|
||||
|
||||
private List<Expression> replaceJoinConjuncts(List<Expression> joinConjuncts, Map<ExprId, Slot> replaceMaps) {
|
||||
return joinConjuncts.stream().map(expr -> expr.rewriteUp(e -> {
|
||||
if (e instanceof Slot && replaceMaps.containsKey(((Slot) e).getExprId())) {
|
||||
return replaceMaps.get(((Slot) e).getExprId());
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
})).collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
}
|
||||
@ -32,8 +32,8 @@ public class RankTest extends TPCHTestBase {
|
||||
@Test
|
||||
void testRank() throws NoSuchFieldException, IllegalAccessException {
|
||||
for (int i = 1; i < 22; i++) {
|
||||
Field field = TPCHUtils.class.getField("Q" + String.valueOf(i));
|
||||
System.out.println("Q" + String.valueOf(i));
|
||||
Field field = TPCHUtils.class.getField("Q" + i);
|
||||
System.out.println("Q" + i);
|
||||
Memo memo = PlanChecker.from(connectContext)
|
||||
.analyze(field.get(null).toString())
|
||||
.rewrite()
|
||||
@ -47,8 +47,8 @@ public class RankTest extends TPCHTestBase {
|
||||
@Test
|
||||
void testUnrank() throws NoSuchFieldException, IllegalAccessException {
|
||||
for (int i = 1; i < 22; i++) {
|
||||
Field field = TPCHUtils.class.getField("Q" + String.valueOf(i));
|
||||
System.out.println("Q" + String.valueOf(i));
|
||||
Field field = TPCHUtils.class.getField("Q" + i);
|
||||
System.out.println("Q" + i);
|
||||
Memo memo = PlanChecker.from(connectContext)
|
||||
.analyze(field.get(null).toString())
|
||||
.rewrite()
|
||||
|
||||
@ -0,0 +1,99 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class PushdownAliasThroughJoinTest implements PatternMatchSupported {
|
||||
|
||||
@Test
|
||||
void testPushdown() {
|
||||
// condition don't use alias slot
|
||||
LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1)
|
||||
.join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.alias(ImmutableList.of(1, 3), ImmutableList.of("1name", "2name"))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new PushdownAliasThroughJoin())
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalProject().when(project -> project.getProjects().get(1).toSql().equals("name AS `1name`")),
|
||||
logicalProject().when(project -> project.getProjects().get(1).toSql().equals("name AS `2name`"))
|
||||
)
|
||||
).when(project -> project.getProjects().get(0).toSql().equals("1name") && project.getProjects().get(1).toSql().equals("2name"))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testCondition() {
|
||||
// condition use alias slot
|
||||
LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1)
|
||||
.join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.alias(ImmutableList.of(0, 1, 3), ImmutableList.of("1id", "1name", "2name"))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new PushdownAliasThroughJoin())
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalProject().when(
|
||||
project -> project.getProjects().get(0).toSql().equals("id AS `1id`")
|
||||
&& project.getProjects().get(1).toSql().equals("name AS `1name`")),
|
||||
logicalProject().when(
|
||||
project -> project.getProjects().get(1).toSql().equals("name AS `2name`"))
|
||||
).when(join -> join.getHashJoinConjuncts().get(0).toSql().equals("(1id = id)"))
|
||||
).when(project -> project.getProjects().get(0).toSql().equals("1id")
|
||||
&& project.getProjects().get(1).toSql().equals("1name")
|
||||
&& project.getProjects().get(2).toSql().equals("2name"))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testJustRightSide() {
|
||||
// condition use alias slot
|
||||
LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.scan1)
|
||||
.join(PlanConstructor.scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.alias(ImmutableList.of(2, 3), ImmutableList.of("2id", "2name"))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyTopDown(new PushdownAliasThroughJoin())
|
||||
.matches(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalOlapScan(),
|
||||
logicalProject()
|
||||
).when(join -> join.getHashJoinConjuncts().get(0).toSql().equals("(id = 2id)"))
|
||||
).when(project -> project.getProjects().get(0).toSql().equals("2id")
|
||||
&& project.getProjects().get(1).toSql().equals("2name"))
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user