[feature](Nereids): add or expansion in CBO(#22465)

This commit is contained in:
谢健
2023-08-03 13:29:33 +08:00
committed by GitHub
parent 85a95e206e
commit 4322fdc96d
5 changed files with 208 additions and 0 deletions

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules;
import org.apache.doris.nereids.rules.exploration.MergeProjectsCBO;
import org.apache.doris.nereids.rules.exploration.OrExpansion;
import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoin;
import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoinProject;
import org.apache.doris.nereids.rules.exploration.join.InnerJoinLAsscom;
@ -98,6 +99,7 @@ public class RuleSet {
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
.add(new MergeProjectsCBO())
.add(new OrExpansion())
.build();
public static final List<Rule> OTHER_REORDER_RULES = planRuleFactories()
@ -196,6 +198,7 @@ public class RuleSet {
.build();
public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
.addAll(EXPLORATION_RULES)
.add(JoinCommute.BUSHY.build())
.build();

View File

@ -250,6 +250,7 @@ public enum RuleType {
// exploration rules
TEST_EXPLORATION(RuleTypeClass.EXPLORATION),
OR_EXPANSION(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_COMMUTE(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION),

View File

@ -0,0 +1,158 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* https://blogs.oracle.com/optimizer/post/optimizer-transformations-or-expansion
* NLJ (cond1 or cond2) UnionAll
* => / \
* HJ(cond1) HJ(cond2 and !cond1)
*/
public class OrExpansion extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin()
.when(JoinUtils::shouldNestedLoopJoin)
.when(join -> join.getJoinType().isInnerJoin())
.thenApply(ctx -> {
LogicalJoin<? extends Plan, ? extends Plan> join = ctx.root;
Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
"Only Expansion nest loop join without hashCond");
List<Expression> disjunctions = null;
List<Expression> otherConditions = Lists.newArrayList(join.getOtherJoinConjuncts());
// We pick the first or condition that can be split to EqualTo expressions.
for (Expression expr : otherConditions) {
Pair<List<Expression>, List<Expression>> pair = expandExpr(expr, join);
// TODO: Now we don't support expand condition with complex expression
if (pair.second.isEmpty() && pair.first.stream()
.noneMatch(e -> !((EqualTo) e).left().isSlot()
&& !((EqualTo) e).right().isSlot())) {
disjunctions = pair.first;
otherConditions.remove(expr);
break;
}
}
// If there is non-EqualTo expression, it means there is nlj child
// Therefore refuse this case
if (disjunctions == null) {
return join;
}
//Construct CTE with the children
LogicalCTEProducer<? extends Plan> leftProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), join.left());
LogicalCTEProducer<? extends Plan> rightProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), join.right());
// expand join to hash join with CTE
List<Plan> joins = expandJoin(ctx.statementContext, disjunctions, otherConditions, join,
leftProducer,
rightProducer);
LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()),
ImmutableList.of(),
false, joins);
LogicalCTEAnchor<? extends Plan, ? extends Plan> intermediateAnchor = new LogicalCTEAnchor<>(
rightProducer.getCteId(), rightProducer, union);
return new LogicalCTEAnchor<Plan, Plan>(leftProducer.getCteId(), leftProducer, intermediateAnchor);
}).toRule(RuleType.OR_EXPANSION);
}
// extract disjunctions for this otherExpr and divide them into HashCond and OtherCond
// return hash conditions and other conditions
private Pair<List<Expression>, List<Expression>> expandExpr(Expression otherExpr,
LogicalJoin<? extends Plan, ? extends Plan> join) {
List<Expression> disjunctions = ExpressionUtils.extractDisjunction(otherExpr);
return JoinUtils.extractExpressionForHashTable(join.left().getOutput(), join.right().getOutput(), disjunctions);
}
private List<Plan> expandJoin(StatementContext ctx, List<Expression> disjunctions, List<Expression> otherConditions,
LogicalJoin<? extends Plan, ? extends Plan> join, LogicalCTEProducer<? extends Plan> leftProducer,
LogicalCTEProducer<? extends Plan> rightProducer) {
List<Expression> notExprs = disjunctions.stream().map(Not::new).collect(Collectors.toList());
List<Plan> joins = Lists.newArrayList();
for (int hashCondIdx = 0; hashCondIdx < disjunctions.size(); hashCondIdx++) {
// extract hash conditions and other condition
Pair<List<Expression>, List<Expression>> pair =
extractHashAndOtherConditions(hashCondIdx, disjunctions, notExprs);
pair.second.addAll(otherConditions);
LogicalCTEConsumer left = new LogicalCTEConsumer(ctx.getNextRelationId(), leftProducer.getCteId(), "",
leftProducer);
LogicalCTEConsumer right = new LogicalCTEConsumer(ctx.getNextRelationId(), rightProducer.getCteId(), "",
rightProducer);
//rewrite conjuncts to replace the old slots with CTE slots
Map<Slot, Slot> replaced = new HashMap<>(left.getProducerToConsumerOutputMap());
replaced.putAll(right.getProducerToConsumerOutputMap());
List<Expression> hashCond = pair.first.stream()
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s))
.collect(Collectors.toList());
List<Expression> otherCond = pair.second.stream()
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s))
.collect(Collectors.toList());
// TODO: normalize join condition
LogicalJoin<? extends Plan, ? extends Plan> newJoin = join.withJoinConjuncts(hashCond, otherCond)
.withChildren(Lists.newArrayList(left, right));
joins.add(newJoin);
}
return joins;
}
// join(a or b or c) = join(a) union join(b) union join(c)
// = join(a) union all (join b and !a) union all join(c and !b and !a)
// return hashConditions and otherConditions
private Pair<List<Expression>, List<Expression>> extractHashAndOtherConditions(int hashCondIdx,
List<Expression> equal, List<Expression> not) {
List<Expression> others = new ArrayList<>();
for (int i = 0; i < hashCondIdx; i++) {
others.add(not.get(i));
}
return Pair.of(Lists.newArrayList(equal.get(hashCondIdx)), others);
}
}