[feature](Nereids): enable OuterJoinAssoc (#19111)

This commit is contained in:
jakevin
2023-04-27 09:50:15 +08:00
committed by GitHub
parent a262f42a28
commit 32fa9e09f4
6 changed files with 76 additions and 24 deletions

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.rules.exploration.join.JoinExchange;
import org.apache.doris.nereids.rules.exploration.join.JoinExchangeBothProject;
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssoc;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.PushdownProjectThroughInnerJoin;
@ -162,6 +163,7 @@ public class RuleSet {
.add(InnerJoinRightAssociateProject.INSTANCE)
.add(JoinExchange.INSTANCE)
.add(JoinExchangeBothProject.INSTANCE)
.add(OuterJoinAssoc.INSTANCE)
.build();
public List<Rule> getOtherReorderRules() {

View File

@ -21,12 +21,14 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableSet;
@ -58,17 +60,29 @@ public class OuterJoinAssoc extends OneExplorationRuleFactory {
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left()))
.when(topJoin -> checkCondition(topJoin, topJoin.left().left().getOutputSet()))
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
.then(topJoin -> {
.thenApply(ctx -> {
LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin = ctx.root;
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topJoin.right();
/* TODO:
* p23 need to reject nulls on A(e2) (Eqv. 1)
* see paper `On the Correct and Complete Enumeration of the Core Search Space`.
* But because we have added eliminate_outer_rule, we don't need to consider this.
/*
* Paper `On the Correct and Complete Enumeration of the Core Search Space`.
* p23 need to reject nulls on A(e2) (Eqv. 1).
* It means that when slot is null, condition must return false or unknown.
*/
if (bottomJoin.getJoinType().isLeftOuterJoin() && topJoin.getJoinType().isLeftOuterJoin()) {
Set<Slot> conditionSlot = topJoin.getConditionSlot();
Set<Expression> on = ImmutableSet.<Expression>builder()
.addAll(topJoin.getHashJoinConjuncts())
.addAll(topJoin.getOtherJoinConjuncts()).build();
Set<Slot> notNullSlots = ExpressionUtils.inferNotNullSlots(on,
ctx.cascadesContext);
if (!conditionSlot.equals(notNullSlots)) {
return null;
}
}
LogicalJoin newBottomJoin = topJoin.withChildrenNoContext(b, c);
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());

View File

@ -23,13 +23,18 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -60,7 +65,8 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet()))
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
.thenApply(ctx -> {
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin = ctx.root;
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
@ -69,6 +75,23 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
GroupPlan c = topJoin.right();
Set<ExprId> aOutputExprIds = a.getOutputExprIdSet();
/*
* Paper `On the Correct and Complete Enumeration of the Core Search Space`.
* p23 need to reject nulls on A(e2) (Eqv. 1).
* It means that when slot is null, condition must return false or unknown.
*/
if (bottomJoin.getJoinType().isLeftOuterJoin() && topJoin.getJoinType().isLeftOuterJoin()) {
Set<Slot> conditionSlot = topJoin.getConditionSlot();
Set<Expression> on = ImmutableSet.<Expression>builder()
.addAll(topJoin.getHashJoinConjuncts())
.addAll(topJoin.getOtherJoinConjuncts()).build();
Set<Slot> notNullSlots = ExpressionUtils.inferNotNullSlots(on,
ctx.cascadesContext);
if (!conditionSlot.equals(notNullSlots)) {
return null;
}
}
/* ********** Split projects ********** */
Map<Boolean, List<NamedExpression>> map = CBOUtils.splitProject(projects, aOutputExprIds);
List<NamedExpression> aProjects = map.get(true);

View File

@ -33,7 +33,6 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.List;
@ -85,7 +84,7 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i
Set<NamedExpression> normalizedWindowWithAlias = ctxForWindows.pushDownToNamedExpression(normalizedWindows);
// only need normalized windowExpressions
LogicalWindow normalizedLogicalWindow =
new LogicalWindow(Lists.newArrayList(normalizedWindowWithAlias), normalizedChild);
new LogicalWindow<>(ImmutableList.copyOf(normalizedWindowWithAlias), normalizedChild);
// 3. handle top projects
List<NamedExpression> topProjects = ctxForWindows.normalizeToUseSlotRef(normalizedOutputs1);

View File

@ -56,35 +56,38 @@ import java.util.stream.Stream;
* Alias(SUM(v1#3 + 1))#7, Alias(SUM(v1#3) + 1)#8])
* </pre>
* After rule:
* <pre>
* Project(k1#1, Alias(SR#9)#4, Alias(k1#1 + 1)#5, Alias(SR#10))#6, Alias(SR#11))#7, Alias(SR#10 + 1)#8)
* +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, Alias(SUM(v1#3 + 1))#11])
* +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
* <p>
*
* </pre>
* Note: window function will be moved to upper project
* all agg functions except the top agg should be pushed to Aggregate node.
* example 1:
* <pre>
* select min(x), sum(x) over () ...
* the 'sum(x)' is top agg of window function, it should be moved to upper project
* plan:
* project(sum(x) over())
* Aggregate(min(x), x)
*
* </pre>
* example 2:
* <pre>
* select min(x), avg(sum(x)) over() ...
* the 'sum(x)' should be moved to Aggregate
* plan:
* project(avg(y) over())
* Aggregate(min(x), sum(x) as y)
* </pre>
* example 3:
* <pre>
* select sum(x+1), x+1, sum(x+1) over() ...
* window function should use x instead of x+1
* plan:
* project(sum(x+1) over())
* Agg(sum(y), x)
* project(x+1 as y)
*
*
* </pre>
* More example could get from UT {NormalizeAggregateTest}
*/
public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
@ -138,8 +141,8 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
// some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace
// the sum(value) to slot and move the `slot + 1` to the upper project later.
List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
aggregate.getOutputExpressions().stream(),
aliasOfAggFunInWindowUsedAsAggOutput.stream())
aggregate.getOutputExpressions().stream(),
aliasOfAggFunInWindowUsedAsAggOutput.stream())
.map(expr -> groupByAndArgumentToSlotContext
.normalizeToUseSlotRefUp(expr, WindowExpression.class::isInstance))
.collect(Collectors.toList());
@ -198,19 +201,14 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
Set<AggregateFunction> aggregateFunctions = collectNonWindowedAggregateFunctions(
aggregate.getOutputExpressions());
ImmutableSet<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(function -> function.getArguments().stream().map(arg -> {
if (arg instanceof OrderExpression) {
return arg.child(0);
} else {
return arg;
}
}))
Set<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(function -> function.getArguments().stream()
.map(expr -> expr instanceof OrderExpression ? expr.child(0) : expr))
.collect(ImmutableSet.toImmutableSet());
Set<Expression> windowFunctionKeys = collectWindowFunctionKeys(aggregate.getOutputExpressions());
ImmutableSet<Expression> needPushDown = ImmutableSet.<Expression>builder()
Set<Expression> needPushDown = ImmutableSet.<Expression>builder()
// group by should be pushed down, e.g. group by (k + 1),
// we should push down the `k + 1` to the bottom plan
.addAll(groupingByExpr)

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
@ -28,6 +29,8 @@ import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.Objects;
@ -78,4 +81,17 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported {
).when(top -> Objects.equals(top.getHashJoinConjuncts().toString(), "[(id#0 = id#2)]"))
);
}
@Test
public void rejectNull() {
IsNull isNull = new IsNull(scan3.getOutput().get(0));
LogicalPlan join = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
.join(scan3, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(), ImmutableList.of(isNull)) // t3.id is not null
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
.applyExploration(OuterJoinAssoc.INSTANCE.build())
.checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()));
}
}