[feature](Nereids): enable OuterJoinAssoc (#19111)
This commit is contained in:
@ -32,6 +32,7 @@ import org.apache.doris.nereids.rules.exploration.join.JoinExchange;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinExchangeBothProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTranspose;
|
||||
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTransposeProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssoc;
|
||||
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscom;
|
||||
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.PushdownProjectThroughInnerJoin;
|
||||
@ -162,6 +163,7 @@ public class RuleSet {
|
||||
.add(InnerJoinRightAssociateProject.INSTANCE)
|
||||
.add(JoinExchange.INSTANCE)
|
||||
.add(JoinExchangeBothProject.INSTANCE)
|
||||
.add(OuterJoinAssoc.INSTANCE)
|
||||
.build();
|
||||
|
||||
public List<Rule> getOtherReorderRules() {
|
||||
|
||||
@ -21,12 +21,14 @@ import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
@ -58,17 +60,29 @@ public class OuterJoinAssoc extends OneExplorationRuleFactory {
|
||||
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left()))
|
||||
.when(topJoin -> checkCondition(topJoin, topJoin.left().left().getOutputSet()))
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().isMarkJoin())
|
||||
.then(topJoin -> {
|
||||
.thenApply(ctx -> {
|
||||
LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin = ctx.root;
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left();
|
||||
GroupPlan a = bottomJoin.left();
|
||||
GroupPlan b = bottomJoin.right();
|
||||
GroupPlan c = topJoin.right();
|
||||
|
||||
/* TODO:
|
||||
* p23 need to reject nulls on A(e2) (Eqv. 1)
|
||||
* see paper `On the Correct and Complete Enumeration of the Core Search Space`.
|
||||
* But because we have added eliminate_outer_rule, we don't need to consider this.
|
||||
/*
|
||||
* Paper `On the Correct and Complete Enumeration of the Core Search Space`.
|
||||
* p23 need to reject nulls on A(e2) (Eqv. 1).
|
||||
* It means that when slot is null, condition must return false or unknown.
|
||||
*/
|
||||
if (bottomJoin.getJoinType().isLeftOuterJoin() && topJoin.getJoinType().isLeftOuterJoin()) {
|
||||
Set<Slot> conditionSlot = topJoin.getConditionSlot();
|
||||
Set<Expression> on = ImmutableSet.<Expression>builder()
|
||||
.addAll(topJoin.getHashJoinConjuncts())
|
||||
.addAll(topJoin.getOtherJoinConjuncts()).build();
|
||||
Set<Slot> notNullSlots = ExpressionUtils.inferNotNullSlots(on,
|
||||
ctx.cascadesContext);
|
||||
if (!conditionSlot.equals(notNullSlots)) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
LogicalJoin newBottomJoin = topJoin.withChildrenNoContext(b, c);
|
||||
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
|
||||
|
||||
@ -23,13 +23,18 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.CBOUtils;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@ -60,7 +65,8 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
|
||||
.when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet()))
|
||||
.when(join -> join.left().isAllSlots())
|
||||
.then(topJoin -> {
|
||||
.thenApply(ctx -> {
|
||||
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin = ctx.root;
|
||||
/* ********** init ********** */
|
||||
List<NamedExpression> projects = topJoin.left().getProjects();
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
|
||||
@ -69,6 +75,23 @@ public class OuterJoinAssocProject extends OneExplorationRuleFactory {
|
||||
GroupPlan c = topJoin.right();
|
||||
Set<ExprId> aOutputExprIds = a.getOutputExprIdSet();
|
||||
|
||||
/*
|
||||
* Paper `On the Correct and Complete Enumeration of the Core Search Space`.
|
||||
* p23 need to reject nulls on A(e2) (Eqv. 1).
|
||||
* It means that when slot is null, condition must return false or unknown.
|
||||
*/
|
||||
if (bottomJoin.getJoinType().isLeftOuterJoin() && topJoin.getJoinType().isLeftOuterJoin()) {
|
||||
Set<Slot> conditionSlot = topJoin.getConditionSlot();
|
||||
Set<Expression> on = ImmutableSet.<Expression>builder()
|
||||
.addAll(topJoin.getHashJoinConjuncts())
|
||||
.addAll(topJoin.getOtherJoinConjuncts()).build();
|
||||
Set<Slot> notNullSlots = ExpressionUtils.inferNotNullSlots(on,
|
||||
ctx.cascadesContext);
|
||||
if (!conditionSlot.equals(notNullSlots)) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/* ********** Split projects ********** */
|
||||
Map<Boolean, List<NamedExpression>> map = CBOUtils.splitProject(projects, aOutputExprIds);
|
||||
List<NamedExpression> aProjects = map.get(true);
|
||||
|
||||
@ -33,7 +33,6 @@ import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.List;
|
||||
@ -85,7 +84,7 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i
|
||||
Set<NamedExpression> normalizedWindowWithAlias = ctxForWindows.pushDownToNamedExpression(normalizedWindows);
|
||||
// only need normalized windowExpressions
|
||||
LogicalWindow normalizedLogicalWindow =
|
||||
new LogicalWindow(Lists.newArrayList(normalizedWindowWithAlias), normalizedChild);
|
||||
new LogicalWindow<>(ImmutableList.copyOf(normalizedWindowWithAlias), normalizedChild);
|
||||
|
||||
// 3. handle top projects
|
||||
List<NamedExpression> topProjects = ctxForWindows.normalizeToUseSlotRef(normalizedOutputs1);
|
||||
|
||||
@ -56,35 +56,38 @@ import java.util.stream.Stream;
|
||||
* Alias(SUM(v1#3 + 1))#7, Alias(SUM(v1#3) + 1)#8])
|
||||
* </pre>
|
||||
* After rule:
|
||||
* <pre>
|
||||
* Project(k1#1, Alias(SR#9)#4, Alias(k1#1 + 1)#5, Alias(SR#10)#6, Alias(SR#11)#7, Alias(SR#10 + 1)#8)
|
||||
* +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, Alias(SUM(v1#3 + 1))#11])
|
||||
* +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
|
||||
* <p>
|
||||
*
|
||||
* </pre>
|
||||
* Note: window function will be moved to upper project
|
||||
* all agg functions except the top agg should be pushed to Aggregate node.
|
||||
* example 1:
|
||||
* <pre>
|
||||
* select min(x), sum(x) over () ...
|
||||
* the 'sum(x)' is top agg of window function, it should be moved to upper project
|
||||
* plan:
|
||||
* project(sum(x) over())
|
||||
* Aggregate(min(x), x)
|
||||
*
|
||||
* </pre>
|
||||
* example 2:
|
||||
* <pre>
|
||||
* select min(x), avg(sum(x)) over() ...
|
||||
* the 'sum(x)' should be moved to Aggregate
|
||||
* plan:
|
||||
* project(avg(y) over())
|
||||
* Aggregate(min(x), sum(x) as y)
|
||||
* </pre>
|
||||
* example 3:
|
||||
* <pre>
|
||||
* select sum(x+1), x+1, sum(x+1) over() ...
|
||||
* window function should use x instead of x+1
|
||||
* plan:
|
||||
* project(sum(x+1) over())
|
||||
* Agg(sum(y), x)
|
||||
* project(x+1 as y)
|
||||
*
|
||||
*
|
||||
* </pre>
|
||||
* More examples can be found in the unit test {NormalizeAggregateTest}
|
||||
*/
|
||||
public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
|
||||
@ -138,8 +141,8 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
// some expression on the aggregate functions, e.g. `sum(value) + 1`, we should replace
|
||||
// the sum(value) to slot and move the `slot + 1` to the upper project later.
|
||||
List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
|
||||
aggregate.getOutputExpressions().stream(),
|
||||
aliasOfAggFunInWindowUsedAsAggOutput.stream())
|
||||
aggregate.getOutputExpressions().stream(),
|
||||
aliasOfAggFunInWindowUsedAsAggOutput.stream())
|
||||
.map(expr -> groupByAndArgumentToSlotContext
|
||||
.normalizeToUseSlotRefUp(expr, WindowExpression.class::isInstance))
|
||||
.collect(Collectors.toList());
|
||||
@ -198,19 +201,14 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
|
||||
Set<AggregateFunction> aggregateFunctions = collectNonWindowedAggregateFunctions(
|
||||
aggregate.getOutputExpressions());
|
||||
|
||||
ImmutableSet<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
|
||||
.flatMap(function -> function.getArguments().stream().map(arg -> {
|
||||
if (arg instanceof OrderExpression) {
|
||||
return arg.child(0);
|
||||
} else {
|
||||
return arg;
|
||||
}
|
||||
}))
|
||||
Set<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
|
||||
.flatMap(function -> function.getArguments().stream()
|
||||
.map(expr -> expr instanceof OrderExpression ? expr.child(0) : expr))
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
|
||||
Set<Expression> windowFunctionKeys = collectWindowFunctionKeys(aggregate.getOutputExpressions());
|
||||
|
||||
ImmutableSet<Expression> needPushDown = ImmutableSet.<Expression>builder()
|
||||
Set<Expression> needPushDown = ImmutableSet.<Expression>builder()
|
||||
// group by should be pushed down, e.g. group by (k + 1),
|
||||
// we should push down the `k + 1` to the bottom plan
|
||||
.addAll(groupingByExpr)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.IsNull;
|
||||
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
@ -28,6 +29,8 @@ import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Objects;
|
||||
@ -78,4 +81,17 @@ class OuterJoinAssocTest implements MemoPatternMatchSupported {
|
||||
).when(top -> Objects.equals(top.getHashJoinConjuncts().toString(), "[(id#0 = id#2)]"))
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Verifies that OuterJoinAssoc produces no alternative plan when the top join condition
 * is an IS NULL predicate.
 *
 * Builds (scan1 LEFT OUTER JOIN scan2) LEFT OUTER JOIN scan3, where the top join has an
 * empty first conjunct list and a single IsNull conjunct on scan3's first output slot
 * (presumably hash conjuncts and other conjuncts respectively; confirm against LogicalPlanBuilder).
 * After applying the exploration rule, the root group must still contain exactly one
 * logical expression.
 */
@Test
|
||||
public void rejectNull() {
|
||||
IsNull isNull = new IsNull(scan3.getOutput().get(0));
|
||||
LogicalPlan join = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.join(scan3, JoinType.LEFT_OUTER_JOIN, ImmutableList.of(), ImmutableList.of(isNull)) // t3.id IS NULL
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
|
||||
.applyExploration(OuterJoinAssoc.INSTANCE.build())
|
||||
.checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user