diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index aa30417fc2..b696b94f59 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -32,6 +32,7 @@ import org.apache.doris.nereids.rules.exploration.join.JoinExchange; import org.apache.doris.nereids.rules.exploration.join.JoinExchangeBothProject; import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTranspose; import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTransposeProject; +import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssoc; import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscom; import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject; import org.apache.doris.nereids.rules.exploration.join.PushdownProjectThroughInnerJoin; @@ -162,6 +163,7 @@ public class RuleSet { .add(InnerJoinRightAssociateProject.INSTANCE) .add(JoinExchange.INSTANCE) .add(JoinExchangeBothProject.INSTANCE) + .add(OuterJoinAssoc.INSTANCE) .build(); public List getOtherReorderRules() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java index 6bb35baa88..2080cfce93 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java @@ -21,12 +21,14 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; import 
org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableSet; @@ -58,17 +60,29 @@ public class OuterJoinAssoc extends OneExplorationRuleFactory { .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left())) .when(topJoin -> checkCondition(topJoin, topJoin.left().left().getOutputSet())) .whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin()) - .then(topJoin -> { + .thenApply(ctx -> { + LogicalJoin, GroupPlan> topJoin = ctx.root; LogicalJoin bottomJoin = topJoin.left(); GroupPlan a = bottomJoin.left(); GroupPlan b = bottomJoin.right(); GroupPlan c = topJoin.right(); - /* TODO: - * p23 need to reject nulls on A(e2) (Eqv. 1) - * see paper `On the Correct and Complete Enumeration of the Core Search Space`. - * But because we have added eliminate_outer_rule, we don't need to consider this. + /* + * Paper `On the Correct and Complete Enumeration of the Core Search Space`. + * p23 need to reject nulls on A(e2) (Eqv. 1). + * It means that when slot is null, condition must return false or unknown. 
*/ + if (bottomJoin.getJoinType().isLeftOuterJoin() && topJoin.getJoinType().isLeftOuterJoin()) { + Set conditionSlot = topJoin.getConditionSlot(); + Set on = ImmutableSet.builder() + .addAll(topJoin.getHashJoinConjuncts()) + .addAll(topJoin.getOtherJoinConjuncts()).build(); + Set notNullSlots = ExpressionUtils.inferNotNullSlots(on, + ctx.cascadesContext); + if (!conditionSlot.equals(notNullSlots)) { + return null; + } + } LogicalJoin newBottomJoin = topJoin.withChildrenNoContext(b, c); newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java index 308e1db583..efd9e14faf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java @@ -23,13 +23,18 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.CBOUtils; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; +import com.google.common.collect.ImmutableSet; + import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -60,7 +65,8 @@ public class 
OuterJoinAssocProject extends OneExplorationRuleFactory { .whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin()) .when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet())) .when(join -> join.left().isAllSlots()) - .then(topJoin -> { + .thenApply(ctx -> { + LogicalJoin>, GroupPlan> topJoin = ctx.root; /* ********** init ********** */ List projects = topJoin.left().getProjects(); LogicalJoin bottomJoin = topJoin.left().child(); @@ -69,6 +75,23 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory { GroupPlan c = topJoin.right(); Set aOutputExprIds = a.getOutputExprIdSet(); + /* + * Paper `On the Correct and Complete Enumeration of the Core Search Space`. + * p23 need to reject nulls on A(e2) (Eqv. 1). + * It means that when slot is null, condition must return false or unknown. + */ + if (bottomJoin.getJoinType().isLeftOuterJoin() && topJoin.getJoinType().isLeftOuterJoin()) { + Set conditionSlot = topJoin.getConditionSlot(); + Set on = ImmutableSet.builder() + .addAll(topJoin.getHashJoinConjuncts()) + .addAll(topJoin.getOtherJoinConjuncts()).build(); + Set notNullSlots = ExpressionUtils.inferNotNullSlots(on, + ctx.cascadesContext); + if (!conditionSlot.equals(notNullSlots)) { + return null; + } + } + /* ********** Split projects ********** */ Map> map = CBOUtils.splitProject(projects, aOutputExprIds); List aProjects = map.get(true); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java index 718d6acc6e..9282ef3825 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java @@ -33,7 +33,6 @@ import 
org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.List; @@ -85,7 +84,7 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i Set normalizedWindowWithAlias = ctxForWindows.pushDownToNamedExpression(normalizedWindows); // only need normalized windowExpressions LogicalWindow normalizedLogicalWindow = - new LogicalWindow(Lists.newArrayList(normalizedWindowWithAlias), normalizedChild); + new LogicalWindow<>(ImmutableList.copyOf(normalizedWindowWithAlias), normalizedChild); // 3. handle top projects List topProjects = ctxForWindows.normalizeToUseSlotRef(normalizedOutputs1); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index cf802245ca..7938cf8769 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -56,35 +56,38 @@ import java.util.stream.Stream; * Alias(SUM(v1#3 + 1))#7, Alias(SUM(v1#3) + 1)#8]) * * After rule: + *
  * Project(k1#1, Alias(SR#9)#4, Alias(k1#1 + 1)#5, Alias(SR#10))#6, Alias(SR#11))#7, Alias(SR#10 + 1)#8)
  * +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, Alias(SUM(v1#3 + 1))#11])
  *   +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
- * 

- * + *

* Note: window function will be moved to upper project * all agg functions except the top agg should be pushed to Aggregate node. * example 1: + *
  *    select min(x), sum(x) over () ...
  * the 'sum(x)' is top agg of window function, it should be moved to upper project
  * plan:
  *    project(sum(x) over())
  *        Aggregate(min(x), x)
- *
+ * 
* example 2: + *
  *    select min(x), avg(sum(x)) over() ...
  * the 'sum(x)' should be moved to Aggregate
  * plan:
  *    project(avg(y) over())
  *         Aggregate(min(x), sum(x) as y)
+ * 
* example 3: + *
  *    select sum(x+1), x+1, sum(x+1) over() ...
  * window function should use x instead of x+1
  * plan:
  *    project(sum(x+1) over())
  *        Agg(sum(y), x)
  *            project(x+1 as y)
- *
- *
+ * 
* More example could get from UT {NormalizeAggregateTest} */ public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot { @@ -138,8 +141,8 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali // some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace // the sum(value) to slot and move the `slot + 1` to the upper project later. List normalizeOutputPhase1 = Stream.concat( - aggregate.getOutputExpressions().stream(), - aliasOfAggFunInWindowUsedAsAggOutput.stream()) + aggregate.getOutputExpressions().stream(), + aliasOfAggFunInWindowUsedAsAggOutput.stream()) .map(expr -> groupByAndArgumentToSlotContext .normalizeToUseSlotRefUp(expr, WindowExpression.class::isInstance)) .collect(Collectors.toList()); @@ -198,19 +201,14 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali Set aggregateFunctions = collectNonWindowedAggregateFunctions( aggregate.getOutputExpressions()); - ImmutableSet argumentsOfAggregateFunction = aggregateFunctions.stream() - .flatMap(function -> function.getArguments().stream().map(arg -> { - if (arg instanceof OrderExpression) { - return arg.child(0); - } else { - return arg; - } - })) + Set argumentsOfAggregateFunction = aggregateFunctions.stream() + .flatMap(function -> function.getArguments().stream() + .map(expr -> expr instanceof OrderExpression ? expr.child(0) : expr)) .collect(ImmutableSet.toImmutableSet()); Set windowFunctionKeys = collectWindowFunctionKeys(aggregate.getOutputExpressions()); - ImmutableSet needPushDown = ImmutableSet.builder() + Set needPushDown = ImmutableSet.builder() // group by should be pushed down, e.g. 
group by (k + 1), // we should push down the `k + 1` to the bottom plan .addAll(groupingByExpr) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java index ad98b7650f..c3beb8fc11 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; @@ -28,6 +29,8 @@ import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.Objects; @@ -78,4 +81,17 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported { ).when(top -> Objects.equals(top.getHashJoinConjuncts().toString(), "[(id#0 = id#2)]")) ); } + + @Test + public void rejectNull() { + IsNull isNull = new IsNull(scan3.getOutput().get(0)); + LogicalPlan join = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .join(scan3, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(), ImmutableList.of(isNull)) // t3.id is null + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), join) + .applyExploration(OuterJoinAssoc.INSTANCE.build()) + .checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size())); + } }