[fix](Nereids): clone the producer plan and put logicalAnchor generated by Or_Expansion above logicalSink (#34771)

* put cte anchor on the root

put logicalAnchor on root

clone plan of cte consumer

* fix unit test
This commit is contained in:
谢健
2024-05-14 15:04:21 +08:00
committed by yiguolei
parent 5ece07ab8c
commit 0deb629d07
3 changed files with 219 additions and 66 deletions

View File

@ -471,7 +471,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(jobs))
),
topic("or expansion",
topDown(new OrExpansion())),
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)),
topic("whole plan check",
custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new)
)

View File

@ -19,10 +19,12 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.rules.rewrite.OrExpansion.OrExpandsionContext;
import org.apache.doris.nereids.trees.copier.DeepCopierContext;
import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
@ -38,8 +40,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.qe.ConnectContext;
@ -53,6 +58,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@ -62,7 +68,7 @@ import javax.annotation.Nullable;
* => / \
* HJ(cond1) HJ(cond2 and !cond1)
*/
public class OrExpansion extends OneExplorationRuleFactory {
public class OrExpansion extends DefaultPlanRewriter<OrExpandsionContext> implements CustomRewriter {
public static final OrExpansion INSTANCE = new OrExpansion();
public static final ImmutableSet<JoinType> supportJoinType = new ImmutableSet
.Builder<JoinType>()
@ -73,63 +79,101 @@ public class OrExpansion extends OneExplorationRuleFactory {
.build();
@Override
public Rule build() {
return logicalJoin(any(), any()).when(JoinUtils::shouldNestedLoopJoin)
.whenNot(LogicalJoin::isMarkJoin)
.when(join -> supportJoinType.contains(join.getJoinType())
&& ConnectContext.get().getSessionVariable().getEnablePipelineEngine())
.thenApply(ctx -> {
LogicalJoin<? extends Plan, ? extends Plan> join = ctx.root;
Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
"Only Expansion nest loop join without hashCond");
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
OrExpandsionContext ctx = new OrExpandsionContext(
jobContext.getCascadesContext().getStatementContext(), jobContext.getCascadesContext());
plan = plan.accept(this, ctx);
for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) {
LogicalCTEProducer<? extends Plan> producer = ctx.cteProducerList.get(i);
plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan);
}
return plan;
}
//1. Try to split or conditions
Pair<List<Expression>, List<Expression>> hashOtherConditions = splitOrCondition(join);
if (hashOtherConditions == null || hashOtherConditions.first.size() <= 1) {
return join;
}
@Override
public Plan visit(Plan plan, OrExpandsionContext ctx) {
List<Plan> newChildren = new ArrayList<>();
boolean hasNewChildren = false;
for (Plan child : plan.children()) {
Plan newChild = child.accept(this, ctx);
if (newChild != child) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
return hasNewChildren ? plan.withChildren(newChildren) : plan;
}
//2. Construct CTE with the children
LogicalCTEProducer<? extends Plan> leftProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), join.left());
LogicalCTEProducer<? extends Plan> rightProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), join.right());
List<Plan> joins = new ArrayList<>();
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, OrExpandsionContext ctx) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) this.visit(join, ctx);
if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) {
return join;
}
if (!(supportJoinType.contains(join.getJoinType())
&& ConnectContext.get().getSessionVariable().getEnablePipelineEngine())) {
return join;
}
Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
"Only Expansion nest loop join without hashCond");
// 3. Expand join to hash join with CTE
if (join.getJoinType().isInnerJoin()) {
joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer));
} else if (join.getJoinType().isOuterJoin()) {
// left outer join = inner join union left anti join
joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer));
joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions, join, leftProducer, rightProducer));
if (join.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) {
// full outer join = inner join union left anti join union right anti join
joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions, join, rightProducer, leftProducer));
}
} else if (join.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) {
joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions, join, leftProducer, rightProducer));
} else {
throw new RuntimeException("or-expansion is not supported for " + join);
}
//1. Try to split or conditions
Pair<List<Expression>, List<Expression>> hashOtherConditions = splitOrCondition(join);
if (hashOtherConditions == null || hashOtherConditions.first.size() <= 1) {
return join;
}
//4. union all joins and construct LogicalCTEAnchor with CTEs
List<List<SlotReference>> childrenOutputs = joins.stream()
.map(j -> j.getOutput().stream()
.map(SlotReference.class::cast)
.collect(ImmutableList.toImmutableList()))
.collect(ImmutableList.toImmutableList());
LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()),
childrenOutputs, ImmutableList.of(), false, joins);
LogicalCTEAnchor<? extends Plan, ? extends Plan> intermediateAnchor = new LogicalCTEAnchor<>(
rightProducer.getCteId(), rightProducer, union);
return new LogicalCTEAnchor<Plan, Plan>(leftProducer.getCteId(), leftProducer, intermediateAnchor);
}).toRule(RuleType.OR_EXPANSION);
//2. Construct CTE with the children
LogicalPlan leftClone = LogicalPlanDeepCopier.INSTANCE
.deepCopy((LogicalPlan) join.left(), new DeepCopierContext());
LogicalCTEProducer<? extends Plan> leftProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), leftClone);
LogicalPlan rightClone = LogicalPlanDeepCopier.INSTANCE
.deepCopy((LogicalPlan) join.right(), new DeepCopierContext());
LogicalCTEProducer<? extends Plan> rightProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), rightClone);
Map<Slot, Slot> leftCloneToLeft = new HashMap<>();
for (int i = 0; i < leftClone.getOutput().size(); i++) {
leftCloneToLeft.put(leftClone.getOutput().get(i), (join.left()).getOutput().get(i));
}
Map<Slot, Slot> rightCloneToRight = new HashMap<>();
for (int i = 0; i < rightClone.getOutput().size(); i++) {
rightCloneToRight.put(rightClone.getOutput().get(i), (join.right()).getOutput().get(i));
}
// 3. Expand join to hash join with CTE
List<Plan> joins = new ArrayList<>();
if (join.getJoinType().isInnerJoin()) {
joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
} else if (join.getJoinType().isOuterJoin()) {
// left outer join = inner join union left anti join
joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions, join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
if (join.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) {
// full outer join = inner join union left anti join union right anti join
joins.add(expandLeftAntiJoin(ctx.cascadesContext, hashOtherConditions,
join, rightProducer, leftProducer, rightCloneToRight, leftCloneToLeft));
}
} else if (join.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) {
joins.add(expandLeftAntiJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
} else {
throw new RuntimeException("or-expansion is not supported for " + join);
}
//4. union all joins and put producers to context
List<List<SlotReference>> childrenOutputs = joins.stream()
.map(j -> j.getOutput().stream()
.map(SlotReference.class::cast)
.collect(ImmutableList.toImmutableList()))
.collect(ImmutableList.toImmutableList());
LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()),
childrenOutputs, ImmutableList.of(), false, joins);
ctx.cteProducerList.add(leftProducer);
ctx.cteProducerList.add(rightProducer);
return union;
}
// try to find a condition that can be split into hash conditions
@ -150,6 +194,18 @@ public class OrExpansion extends OneExplorationRuleFactory {
return null;
}
private Map<Slot, Slot> constructReplaceMap(LogicalCTEConsumer leftConsumer, Map<Slot, Slot> leftCloneToLeft,
LogicalCTEConsumer rightConsumer, Map<Slot, Slot> rightCloneToRight) {
Map<Slot, Slot> replaced = new HashMap<>();
for (Entry<Slot, Slot> entry : leftConsumer.getProducerToConsumerOutputMap().entrySet()) {
replaced.put(leftCloneToLeft.get(entry.getKey()), entry.getValue());
}
for (Entry<Slot, Slot> entry : rightConsumer.getProducerToConsumerOutputMap().entrySet()) {
replaced.put(rightCloneToRight.get(entry.getKey()), entry.getValue());
}
return replaced;
}
// expand Anti Join:
// Left Anti join cond1 or cond2, other Left Anti join cond1 and other
// / \ / \
@ -160,7 +216,8 @@ public class OrExpansion extends OneExplorationRuleFactory {
Pair<List<Expression>, List<Expression>> hashOtherConditions,
LogicalJoin<? extends Plan, ? extends Plan> originJoin,
LogicalCTEProducer<? extends Plan> leftProducer,
LogicalCTEProducer<? extends org.apache.doris.nereids.trees.plans.Plan> rightProducer) {
LogicalCTEProducer<? extends org.apache.doris.nereids.trees.plans.Plan> rightProducer,
Map<Slot, Slot> leftCloneToLeft, Map<Slot, Slot> rightCloneToRight) {
LogicalCTEConsumer left = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
leftProducer.getCteId(), "", leftProducer);
LogicalCTEConsumer right = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
@ -168,8 +225,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
ctx.putCTEIdToConsumer(left);
ctx.putCTEIdToConsumer(right);
Map<Slot, Slot> replaced = new HashMap<>(left.getProducerToConsumerOutputMap());
replaced.putAll(right.getProducerToConsumerOutputMap());
Map<Slot, Slot> replaced = constructReplaceMap(left, leftCloneToLeft, right, rightCloneToRight);
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
List<Expression> newOtherConditions = otherConditions.stream()
@ -191,8 +247,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
LogicalCTEConsumer newRight = new LogicalCTEConsumer(
ctx.getStatementContext().getNextRelationId(), rightProducer.getCteId(), "", rightProducer);
ctx.putCTEIdToConsumer(newRight);
Map<Slot, Slot> newReplaced = new HashMap<>(left.getProducerToConsumerOutputMap());
newReplaced.putAll(newRight.getProducerToConsumerOutputMap());
Map<Slot, Slot> newReplaced = constructReplaceMap(left, leftCloneToLeft, newRight, rightCloneToRight);
newOtherConditions = otherConditions.stream()
.map(e -> e.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s))
.collect(Collectors.toList());
@ -224,7 +279,8 @@ public class OrExpansion extends OneExplorationRuleFactory {
private List<Plan> expandInnerJoin(CascadesContext ctx, Pair<List<Expression>,
List<Expression>> hashOtherConditions,
LogicalJoin<? extends Plan, ? extends Plan> join, LogicalCTEProducer<? extends Plan> leftProducer,
LogicalCTEProducer<? extends Plan> rightProducer) {
LogicalCTEProducer<? extends Plan> rightProducer,
Map<Slot, Slot> leftCloneToLeft, Map<Slot, Slot> rightCloneToRight) {
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
// For null values, equalTo and not equalTo both return false
@ -248,8 +304,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
ctx.putCTEIdToConsumer(right);
//rewrite conjuncts to replace the old slots with CTE slots
Map<Slot, Slot> replaced = new HashMap<>(left.getProducerToConsumerOutputMap());
replaced.putAll(right.getProducerToConsumerOutputMap());
Map<Slot, Slot> replaced = constructReplaceMap(left, leftCloneToLeft, right, rightCloneToRight);
List<Expression> hashCond = pair.first.stream()
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s))
.collect(Collectors.toList());
@ -283,4 +338,16 @@ public class OrExpansion extends OneExplorationRuleFactory {
}
return Pair.of(Lists.newArrayList(equal.get(hashCondIdx)), others);
}
class OrExpandsionContext {
List<LogicalCTEProducer<? extends Plan>> cteProducerList;
StatementContext statementContext;
CascadesContext cascadesContext;
public OrExpandsionContext(StatementContext statementContext, CascadesContext cascadesContext) {
this.statementContext = statementContext;
this.cteProducerList = new ArrayList<>();
this.cascadesContext = cascadesContext;
}
}
}

View File

@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
class OrExpansionTest extends TestWithFeService implements MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
connectContext.setDatabase("default_cluster:test");
createTables(
"CREATE TABLE IF NOT EXISTS t1 (\n"
+ " id1 int not null,\n"
+ " id2 int not null\n"
+ ")\n"
+ "DUPLICATE KEY(id1)\n"
+ "DISTRIBUTED BY HASH(id1) BUCKETS 10\n"
+ "PROPERTIES (\"replication_num\" = \"1\")\n",
"CREATE TABLE IF NOT EXISTS t2 (\n"
+ " id1 int not null,\n"
+ " id2 int not null\n"
+ ")\n"
+ "DUPLICATE KEY(id1)\n"
+ "DISTRIBUTED BY HASH(id2) BUCKETS 10\n"
+ "PROPERTIES (\"replication_num\" = \"1\")\n"
);
}
@Test
void testOrExpand() {
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
String sql = "select t1.id1 + 1 as id from t1 join t2 on t1.id1 = t2.id1 or t1.id2 = t2.id2";
Plan plan = PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.printlnTree()
.getPlan();
Assertions.assertTrue(plan instanceof LogicalCTEAnchor);
Assertions.assertTrue(plan.child(1) instanceof LogicalCTEAnchor);
}
@Test
void testOrExpandCTE() {
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
connectContext.getSessionVariable().inlineCTEReferencedThreshold = 0;
String sql = "with t3 as (select t1.id1 + 1 as id1, t1.id2 + 2 as id2 from t1), "
+ "t4 as (select t2.id1 + 1 as id1, t2.id2 + 2 as id2 from t2) "
+ "select t3.id1 from "
+ "(select id1, id2 from t3 group by id1, id2) t3 "
+ " join "
+ "(select id1, id2 from t4 group by id1, id2) t4 "
+ "on t3.id1 = t4.id1 or t3.id2 = t4.id2";
Plan plan = PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.printlnTree()
.getPlan();
Assertions.assertTrue(plan instanceof LogicalCTEAnchor);
Assertions.assertTrue(plan.child(1) instanceof LogicalCTEAnchor);
Assertions.assertTrue(plan.child(1).child(1) instanceof LogicalCTEAnchor);
Assertions.assertTrue(plan.child(1).child(1).child(1) instanceof LogicalCTEAnchor);
}
}