diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index f3e9bf5c55..4ed887193f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules; import org.apache.doris.nereids.rules.exploration.MergeProjectsCBO; +import org.apache.doris.nereids.rules.exploration.OrExpansion; import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoin; import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoinProject; import org.apache.doris.nereids.rules.exploration.join.InnerJoinLAsscom; @@ -98,6 +99,7 @@ public class RuleSet { public static final List EXPLORATION_RULES = planRuleFactories() .add(new MergeProjectsCBO()) + .add(new OrExpansion()) .build(); public static final List OTHER_REORDER_RULES = planRuleFactories() @@ -196,6 +198,7 @@ public class RuleSet { .build(); public static final List DPHYP_REORDER_RULES = ImmutableList.builder() + .addAll(EXPLORATION_RULES) .add(JoinCommute.BUSHY.build()) .build(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 114c0529fa..d43438ee85 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -250,6 +250,7 @@ public enum RuleType { // exploration rules TEST_EXPLORATION(RuleTypeClass.EXPLORATION), + OR_EXPANSION(RuleTypeClass.EXPLORATION), LOGICAL_JOIN_COMMUTE(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION), LOGICAL_INNER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/OrExpansion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/OrExpansion.java new file mode 100644 index 0000000000..fcb46525fa --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/OrExpansion.java @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.StatementContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.JoinUtils; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * https://blogs.oracle.com/optimizer/post/optimizer-transformations-or-expansion + * NLJ (cond1 or cond2) UnionAll + * => / \ + * HJ(cond1) HJ(cond2 and !cond1) + */ +public class OrExpansion extends OneExplorationRuleFactory { + + @Override + public Rule build() { + return logicalJoin() + .when(JoinUtils::shouldNestedLoopJoin) + .when(join -> join.getJoinType().isInnerJoin()) + .thenApply(ctx -> { + LogicalJoin join = ctx.root; + Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(), + "Only Expansion nest loop join without hashCond"); + List disjunctions = null; + List otherConditions = Lists.newArrayList(join.getOtherJoinConjuncts()); + + // We pick the first or condition that can be split to EqualTo expressions. + for (Expression expr : otherConditions) { + Pair, List> pair = expandExpr(expr, join); + // TODO: Now we don't support expand condition with complex expression + if (pair.second.isEmpty() && pair.first.stream() + .noneMatch(e -> !((EqualTo) e).left().isSlot() + && !((EqualTo) e).right().isSlot())) { + disjunctions = pair.first; + otherConditions.remove(expr); + break; + } + } + // If there is non-EqualTo expression, it means there is nlj child + // Therefore refuse this case + if (disjunctions == null) { + return join; + } + //Construct CTE with the children + LogicalCTEProducer leftProducer = new LogicalCTEProducer<>( + ctx.statementContext.getNextCTEId(), join.left()); + LogicalCTEProducer rightProducer = new LogicalCTEProducer<>( + ctx.statementContext.getNextCTEId(), join.right()); + // expand join to hash join with CTE + List joins = expandJoin(ctx.statementContext, disjunctions, otherConditions, join, + leftProducer, + rightProducer); + + LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()), + ImmutableList.of(), + false, joins); + LogicalCTEAnchor intermediateAnchor = new LogicalCTEAnchor<>( + rightProducer.getCteId(), rightProducer, union); + return new LogicalCTEAnchor(leftProducer.getCteId(), leftProducer, intermediateAnchor); + }).toRule(RuleType.OR_EXPANSION); + } + + // extract disjunctions for this otherExpr and divide them into HashCond and OtherCond + // return hash conditions and other conditions + private Pair, List> expandExpr(Expression otherExpr, + LogicalJoin join) { + List disjunctions = ExpressionUtils.extractDisjunction(otherExpr); + return JoinUtils.extractExpressionForHashTable(join.left().getOutput(), join.right().getOutput(), disjunctions); + } + + private List expandJoin(StatementContext ctx, List disjunctions, List otherConditions, + LogicalJoin join, LogicalCTEProducer leftProducer, + LogicalCTEProducer rightProducer) { + List notExprs = disjunctions.stream().map(Not::new).collect(Collectors.toList()); + List joins = Lists.newArrayList(); + + for (int hashCondIdx = 0; hashCondIdx < disjunctions.size(); hashCondIdx++) { + // extract hash conditions and other condition + Pair, List> pair = + extractHashAndOtherConditions(hashCondIdx, disjunctions, notExprs); + pair.second.addAll(otherConditions); + + LogicalCTEConsumer left = new LogicalCTEConsumer(ctx.getNextRelationId(), leftProducer.getCteId(), "", + leftProducer); + LogicalCTEConsumer right = new LogicalCTEConsumer(ctx.getNextRelationId(), rightProducer.getCteId(), "", + rightProducer); + + //rewrite conjuncts to replace the old slots with CTE slots + Map replaced = new HashMap<>(left.getProducerToConsumerOutputMap()); + replaced.putAll(right.getProducerToConsumerOutputMap()); + List hashCond = pair.first.stream() + .map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s)) + .collect(Collectors.toList()); + List otherCond = pair.second.stream() + .map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s)) + .collect(Collectors.toList()); + + // TODO: normalize join condition + LogicalJoin newJoin = join.withJoinConjuncts(hashCond, otherCond) + .withChildren(Lists.newArrayList(left, right)); + joins.add(newJoin); + } + return joins; + } + + // join(a or b or c) = join(a) union join(b) union join(c) + // = join(a) union all (join b and !a) union all join(c and !b and !a) + // return hashConditions and otherConditions + private Pair, List> extractHashAndOtherConditions(int hashCondIdx, + List equal, List not) { + List others = new ArrayList<>(); + for (int i = 0; i < hashCondIdx; i++) { + others.add(not.get(i)); + } + return Pair.of(Lists.newArrayList(equal.get(hashCondIdx)), others); + } +} diff --git a/regression-test/data/nereids_p0/union/or_expansion.out b/regression-test/data/nereids_p0/union/or_expansion.out new file mode 100644 index 0000000000..a5fe6f155e --- /dev/null +++ b/regression-test/data/nereids_p0/union/or_expansion.out @@ -0,0 +1,9 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !nsj -- +false 1 1989 1001 11011902 123.123 true 1989-03-21 1989-03-21T13:00 wangjuoo4 0.1 6.333 string12345 170141183460469231731687303715884105727 false 1 1989 1001 11011902 123.123 true 1989-03-21 1989-03-21T13:00 wangjuoo4 0.1 6.333 string12345 170141183460469231731687303715884105727 +false 1 1989 1001 11011902 123.123 true 1989-03-21 1989-03-21T13:00 wangjuoo4 0.1 6.333 string12345 170141183460469231731687303715884105727 false 2 1986 1001 11011903 1243.500 false 1901-12-31 1989-03-21T13:00 wangynnsf 20.268 789.25 string12345 -170141183460469231731687303715884105727 +false 2 1986 1001 11011903 1243.500 false 1901-12-31 1989-03-21T13:00 wangynnsf 20.268 789.25 string12345 -170141183460469231731687303715884105727 false 1 1989 1001 11011902 123.123 true 1989-03-21 1989-03-21T13:00 wangjuoo4 0.1 6.333 string12345 170141183460469231731687303715884105727 +false 2 1986 1001 11011903 1243.500 false 1901-12-31 1989-03-21T13:00 wangynnsf 20.268 789.25 string12345 -170141183460469231731687303715884105727 false 2 1986 1001 11011903 1243.500 false 1901-12-31 1989-03-21T13:00 wangynnsf 20.268 789.25 string12345 -170141183460469231731687303715884105727 +false 3 1989 1002 11011905 24453.325 false 2012-03-14 2000-01-01T00:00 yunlj8@nk 78945.0 3654.0 string12345 0 false 3 1989 1002 11011905 24453.325 false 2012-03-14 2000-01-01T00:00 yunlj8@nk 78945.0 3654.0 string12345 0 +false 3 1989 1002 11011905 24453.325 false 2012-03-14 2000-01-01T00:00 yunlj8@nk 78945.0 3654.0 string12345 0 false 7 -32767 1002 7210457 3.141 false 1988-03-21 1901-01-01T00:00 jiw3n4 0.0 6058.0 string12345 -20220101 + diff --git a/regression-test/suites/nereids_p0/union/or_expansion.groovy b/regression-test/suites/nereids_p0/union/or_expansion.groovy new file mode 100644 index 0000000000..6528719667 --- /dev/null +++ b/regression-test/suites/nereids_p0/union/or_expansion.groovy @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("or_expansion") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + def db = "nereids_test_query_db" + sql "use ${db}" + + explain { + sql("""select * from bigtable + join baseall + on baseall.k0 = bigtable.k0 + or baseall.k1 = bigtable.k1""") + contains "VHASH JOIN" + } + + order_qt_nsj """select * from test + join baseall + on baseall.k1 = test.k1 + or baseall.k3 = test.k3 + """ +}